diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1a9eefc47ae17b..cc6ff752c7701e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -326,6 +326,8 @@ title: CamemBERT - local: model_doc/canine title: CANINE + - local: model_doc/chameleon + title: chameleon - local: model_doc/codegen title: CodeGen - local: model_doc/code_llama diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 3691bff960e3a2..92cbdd44d7c0ea 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -88,6 +88,7 @@ Flax), PyTorch, and/or TensorFlow. | [ByT5](model_doc/byt5) | ✅ | ✅ | ✅ | | [CamemBERT](model_doc/camembert) | ✅ | ✅ | ❌ | | [CANINE](model_doc/canine) | ✅ | ❌ | ❌ | +| [Chameleon](model_doc/chameleon) | ✅ | ❌ | ❌ | | [Chinese-CLIP](model_doc/chinese_clip) | ✅ | ❌ | ❌ | | [CLAP](model_doc/clap) | ✅ | ❌ | ❌ | | [CLIP](model_doc/clip) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md new file mode 100644 index 00000000000000..e2a0012ba97f2c --- /dev/null +++ b/docs/source/en/model_doc/chameleon.md @@ -0,0 +1,189 @@ + + +# Chameleon + +## Overview + +The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models +](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. + + +The abstract from the paper is the following: + +*We present Chameleon, a family of early-fusion token-based mixed-modal models capable of understanding and generating images and text in any arbitrary sequence. We outline a stable training +approach from inception, an alignment recipe, and an architectural parameterization tailored for the +early-fusion, token-based, mixed-modal setting. The models are evaluated on a comprehensive range +of tasks, including visual question answering, image captioning, text generation, image generation, and +long-form mixed modal generation. Chameleon demonstrates broad and general capabilities, including +state-of-the-art performance in image captioning tasks, outperforms Llama-2 in text-only tasks while +being competitive with models such as Mixtral 8x7B and Gemini-Pro, and performs non-trivial image +generation, all in a single model. It also matches or exceeds the performance of much larger models, +including Gemini Pro and GPT-4V, according to human judgments on a new long-form mixed-modal +generation evaluation, where either the prompt or outputs contain mixed sequences of both images and +text. Chameleon marks a significant step forward in a unified modeling of full multimodal documents* + + + + + Chameleon incorporates a vector quantizer module to transform images into discrete tokens. That also enables image geenration using an auto-regressive transformer. Taken from the original paper. + +This model was contributed by [joaogante](https://huggingface.co/joaogante) and [RaushanTurganbay](https://huggingface.co/RaushanTurganbay). +The original code can be found [here](https://github.com/facebookresearch/chameleon). + + +## Usage tips + +- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating. + +- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question. + +- Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor. + +> [!NOTE] +> Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: ``. + +## Usage example + +### Single image inference + +Here's how to load the model and perform inference in half-precision (`torch.float16`): + +```python +from transformers import ChameleonProcessor, ChameleonForCausalLM +import torch +from PIL import Image +import requests + +processor = ChameleonProcessor.from_pretrained("meta-chameleon") +model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto") + +# prepare image and text prompt +url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg" +image = Image.open(requests.get(url, stream=True).raw) +prompt = "What color is the belt in this image?" + +inputs = processor(prompt, image, return_tensors="pt").to(model.device) + +# autoregressively complete prompt +output = model.generate(**inputs, max_new_tokens=50) +print(processor.decode(output[0], skip_special_tokens=True)) +``` + +### Multi image inference + +Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it: + +```python +from transformers import ChameleonProcessor, ChameleonForCausalLM +import torch +from PIL import Image +import requests + +processor = ChameleonProcessor.from_pretrained("meta-chameleon") +model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto") + +# Get three different images +url = "https://www.ilankelman.org/stopsigns/australia.jpg" +image_stop = Image.open(requests.get(url, stream=True).raw) + +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image_cats = Image.open(requests.get(url, stream=True).raw) + +url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +image_snowman = Image.open(requests.get(url, stream=True).raw) + +# Prepare a batched prompt, where the first one is a multi-image prompt and the second is not +prompts = [ + "What do these images have in common?", + "What is shown in this image?" +] + +# We can simply feed images in the order they have to be used in the text prompt +# Each "" token uses one image leaving the next for the subsequent "" tokens +inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) + +# Generate +generate_ids = model.generate(**inputs, max_new_tokens=50) +processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) +``` + +## Model optimization + +### Quantization using Bitsandbytes + +The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with: + +```python +from transformers import ChameleonForCausalLM, BitsAndBytesConfig + +# specify how to quantize the model +quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, +) + +model = ChameleonForCausalLM.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto") +``` + +### Use Flash-Attention 2 and SDPA to further speed-up generation + +The models supports both, Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) which can be enables for optimization. SDPA is the default options when you load the model, If you want to switch for Flash Attention 2, first make sure to install flash-attn. Refer to the [original repository](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with: + +```python +from transformers import ChameleonForCausalLM + +model = ChameleonForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="flash_attention_2" +).to(0) +``` + +## ChameleonConfig + +[[autodoc]] ChameleonConfig + +## ChameleonVQVAEConfig + +[[autodoc]] ChameleonVQVAEConfig + +## ChameleonProcessor + +[[autodoc]] ChameleonProcessor + +## ChameleonImageProcessor + +[[autodoc]] ChameleonImageProcessor + - preprocess + +## ChameleonVQVAE + +[[autodoc]] ChameleonVQVAE + - forward + +## ChameleonModel + +[[autodoc]] ChameleonModel + - forward + +## ChameleonForCausalLM + +[[autodoc]] ChameleonForCausalLM + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index b18e737ff97361..396c7bc2a9db06 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions. FlashAttention-2 is currently supported for the following architectures: * [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) @@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) +* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 344da0d3dc8ef3..6d52b30e1ca541 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -249,6 +249,11 @@ "CanineConfig", "CanineTokenizer", ], + "models.chameleon": [ + "ChameleonConfig", + "ChameleonProcessor", + "ChameleonVQVAEConfig", + ], "models.chinese_clip": [ "ChineseCLIPConfig", "ChineseCLIPProcessor", @@ -1125,6 +1130,7 @@ _import_structure["models.bit"].extend(["BitImageProcessor"]) _import_structure["models.blip"].extend(["BlipImageProcessor"]) _import_structure["models.bridgetower"].append("BridgeTowerImageProcessor") + _import_structure["models.chameleon"].append("ChameleonImageProcessor") _import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"]) _import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"]) _import_structure["models.conditional_detr"].extend( @@ -1608,6 +1614,15 @@ "load_tf_weights_in_canine", ] ) + _import_structure["models.chameleon"].extend( + [ + "ChameleonForCausalLM", + "ChameleonModel", + "ChameleonPreTrainedModel", + "ChameleonProcessor", + "ChameleonVQVAE", + ] + ) _import_structure["models.chinese_clip"].extend( [ "ChineseCLIPModel", @@ -4890,6 +4905,11 @@ CanineConfig, CanineTokenizer, ) + from .models.chameleon import ( + ChameleonConfig, + ChameleonProcessor, + ChameleonVQVAEConfig, + ) from .models.chinese_clip import ( ChineseCLIPConfig, ChineseCLIPProcessor, @@ -5807,6 +5827,7 @@ from .models.bit import BitImageProcessor from .models.blip import BlipImageProcessor from .models.bridgetower import BridgeTowerImageProcessor + from .models.chameleon import ChameleonImageProcessor from .models.chinese_clip import ( ChineseCLIPFeatureExtractor, ChineseCLIPImageProcessor, @@ -6254,6 +6275,13 @@ CaninePreTrainedModel, load_tf_weights_in_canine, ) + from .models.chameleon import ( + ChameleonForCausalLM, + ChameleonModel, + ChameleonPreTrainedModel, + ChameleonProcessor, + ChameleonVQVAE, + ) from .models.chinese_clip import ( ChineseCLIPModel, ChineseCLIPPreTrainedModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index cd3cafa9620896..cc1e41b3fc4076 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -42,6 +42,7 @@ byt5, camembert, canine, + chameleon, chinese_clip, clap, clip, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index df73312c74b969..512c1eaaf5e01a 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -55,6 +55,7 @@ ("bros", "BrosConfig"), ("camembert", "CamembertConfig"), ("canine", "CanineConfig"), + ("chameleon", "ChameleonConfig"), ("chinese_clip", "ChineseCLIPConfig"), ("chinese_clip_vision_model", "ChineseCLIPVisionConfig"), ("clap", "ClapConfig"), @@ -329,6 +330,7 @@ ("byt5", "ByT5"), ("camembert", "CamemBERT"), ("canine", "CANINE"), + ("chameleon", "Chameleon"), ("chinese_clip", "Chinese-CLIP"), ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), ("clap", "CLAP"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 388d4d46f0fe26..8bfc61b9bea349 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -59,6 +59,7 @@ ("blip", ("BlipImageProcessor",)), ("blip-2", ("BlipImageProcessor",)), ("bridgetower", ("BridgeTowerImageProcessor",)), + ("chameleon", ("ChameleonImageProcessor",)), ("chinese_clip", ("ChineseCLIPImageProcessor",)), ("clip", ("CLIPImageProcessor",)), ("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index bf46276def01b5..b99cbe19bbd644 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -55,6 +55,7 @@ ("bros", "BrosModel"), ("camembert", "CamembertModel"), ("canine", "CanineModel"), + ("chameleon", "ChameleonModel"), ("chinese_clip", "ChineseCLIPModel"), ("chinese_clip_vision_model", "ChineseCLIPVisionModel"), ("clap", "ClapModel"), @@ -445,6 +446,7 @@ ("blenderbot-small", "BlenderbotSmallForCausalLM"), ("bloom", "BloomForCausalLM"), ("camembert", "CamembertForCausalLM"), + ("chameleon", "ChameleonForCausalLM"), ("code_llama", "LlamaForCausalLM"), ("codegen", "CodeGenForCausalLM"), ("cohere", "CohereForCausalLM"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 631fee8f8cc444..1ab136a1e74ca7 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -51,6 +51,7 @@ ("blip", "BlipProcessor"), ("blip-2", "Blip2Processor"), ("bridgetower", "BridgeTowerProcessor"), + ("chameleon", "ChameleonProcessor"), ("chinese_clip", "ChineseCLIPProcessor"), ("clap", "ClapProcessor"), ("clip", "CLIPProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index dddab5379f5657..55ea0794d04c7e 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -107,6 +107,13 @@ ), ), ("canine", ("CanineTokenizer", None)), + ( + "chameleon", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( "clap", diff --git a/src/transformers/models/chameleon/__init__.py b/src/transformers/models/chameleon/__init__.py new file mode 100644 index 00000000000000..71e40a5da4afa3 --- /dev/null +++ b/src/transformers/models/chameleon/__init__.py @@ -0,0 +1,83 @@ +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, + is_vision_available, +) + + +_import_structure = { + "configuration_chameleon": ["ChameleonConfig", "ChameleonVQVAEConfig"], + "processing_chameleon": ["ChameleonProcessor"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_chameleon"] = [ + "ChameleonForCausalLM", + "ChameleonModel", + "ChameleonPreTrainedModel", + "ChameleonVQVAE", + ] + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_chameleon"] = ["ChameleonImageProcessor"] + + +if TYPE_CHECKING: + from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig + from .processing_chameleon import ChameleonProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_chameleon import ( + ChameleonForCausalLM, + ChameleonModel, + ChameleonPreTrainedModel, + ChameleonVQVAE, + ) + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_chameleon import ChameleonImageProcessor + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py new file mode 100644 index 00000000000000..67de37f2d01b2c --- /dev/null +++ b/src/transformers/models/chameleon/configuration_chameleon.py @@ -0,0 +1,276 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""chameleon model configuration""" + +from typing import List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ChameleonVQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChameleonVQModel`]. It is used to instantiate a + `ChameleonVQModel` according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a + configuration with the defaults will yield a similar configuration to the VQModel of the + [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). + + Args: + embed_dim (`int`, *optional*, defaults to 256): + Dimensionality of each embedding vector. + num_embeddings (`int`, *optional*, defaults to 8192): + Number of codebook embeddings. + double_latent (`bool`, *optional*, defaults to `False`): + Whether to use double z channels. + latent_channels (`int`, *optional*, defaults to 256): + Number of channels for the latent space. + resolution (`int`, *optional*, defaults to 512): + Resolution of the input images. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + base_channels (`int`, *optional*, defaults to 128): + Base channel count. + channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): + Channel multipliers for each resolution. + num_res_blocks (`int`, *optional*, defaults to 2): + Number of residual blocks. + attn_resolutions (`List[int]`, *optional*): + Resolutions to apply attention. + dropout (`float`, *optional*, defaults to 0.0): + Dropout rate. + attn_type (`str`, *optional*, defaults to `"vanilla"`): + Attention type used in VQ-GAN encoder. Can be "vanilla" or None. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + """ + + model_type = "chameleon_vqgan" + + def __init__( + self, + embed_dim: int = 256, + num_embeddings: int = 8192, + double_latent: bool = False, + latent_channels: int = 256, + resolution: int = 512, + in_channels: int = 3, + base_channels: int = 128, + channel_multiplier: List[int] = [1, 1, 2, 2, 4], + num_res_blocks: int = 2, + attn_resolutions: List[int] = None, + dropout: float = 0.0, + attn_type: str = "vanilla", + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_embeddings = num_embeddings + self.double_latent = double_latent + self.latent_channels = latent_channels + self.resolution = resolution + self.in_channels = in_channels + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.dropout = dropout + self.attn_type = attn_type + self.initializer_range = initializer_range + + +class ChameleonConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ChameleonModel`]. It is used to instantiate a + chameleon model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [meta/chameleon-7B](https://huggingface.co/meta/chameleon-7B). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the chameleon model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ChameleonModel`]; this includes text and image tokens. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. Chameleon supports up to 4096 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/Localchameleon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + model_parallel_size (`int`, *optional*, defaults to 1): + Number of shards used when training the model. This will be used in qk layernorm because the original Chameleon inference + doesn't do reduction in those layers and each rank has its own biases. + swin_norm (`bool`, *optional*, defaults to `False`): + Use Swin Transformer normalization. + vq_config (`dict`, *optional*): + ChameleonVQConfig instance containing the configuration for the VQ-VAE model. + vocabulary_map (`dict`, *optional*): + A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + + ```python + >>> from transformers import ChameleonModel, ChameleonConfig + + >>> # Initializing a chameleon chameleon-7b style configuration + >>> configuration = ChameleonConfig() + + >>> # Initializing a model from the chameleon-7b style configuration + >>> model = ChameleonModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "chameleon" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=65536, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + model_parallel_size=1, + swin_norm=False, + vq_config=None, + vocabulary_map=None, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_bias = mlp_bias + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.model_parallel_size = model_parallel_size + self.swin_norm = swin_norm + + if vq_config is None: + vq_config = {} + logger.info("vq_config is None. initializing the ChameleonVQConfig with default values.") + + self.vq_config = ChameleonVQVAEConfig(**vq_config) + + self.vocabulary_map = vocabulary_map + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py new file mode 100644 index 00000000000000..2c1f5e89cb2405 --- /dev/null +++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py @@ -0,0 +1,476 @@ +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os + +import requests +import torch +import yaml +from accelerate import init_empty_weights +from PIL import Image + +from transformers import ( + ChameleonConfig, + ChameleonForCausalLM, + ChameleonImageProcessor, + ChameleonProcessor, +) + + +try: + from transformers import LlamaTokenizerFast +except ImportError: + raise ValueError( + "Chameleon conversion supports only FastTokenizer and LlamaTokenizerFast can't be imported! " + "Update your `tokenizers` library and re-run the tokenizer conversion." + ) + +""" +Sample usage: + +``` +python src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py \ + --input_dir /path/to/downloaded/chameleon/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import ChameleonForCausalLM, LlamaTokenizer + +model = ChameleonForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +NUM_SHARDS = { + "7B": 1, + "30B": 4, +} + +VOCAB_SIZE = 65536 + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, model_size, chameleon_version=1): + os.makedirs(model_path, exist_ok=True) + input_model_path = os.path.join(input_base_path, "models", model_size.lower()) + params_path = os.path.join(input_model_path, "params.json") + consolidate_params_path = os.path.join(input_model_path, "consolidate_params.json") + + params = read_json(params_path) + if os.path.isfile(consolidate_params_path): + params = {**params, **read_json(consolidate_params_path)} + num_shards = NUM_SHARDS[model_size] + model_parallel_size = params["model_parallel_size"] + params = params.get("model", params) + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = params.get("rope_theta", 10000.0) + swin_norm = params["swin_norm"] + if base > 10000.0: + max_position_embeddings = 16384 + else: + # Depending on the Chameleon version, the default max_position_embeddings has different values. + if chameleon_version == 1: + max_position_embeddings = 4096 + else: + raise NotImplementedError( + f"Version {chameleon_version} of chameleon is not supported yet. " + "Current supported versions of chameleon are [1]." + ) + + if params.get("n_kv_heads", None) is not None: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_local_key_value_heads = n_heads_per_shard // num_key_value_heads + key_value_dim = dim // num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_local_key_value_heads = n_heads_per_shard + key_value_dim = dim + + print(f"Fetching all parameters from the checkpoint at {input_model_path}.") + # Load weights + if num_shards == 1: + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = None + for possible_name in ["consolidated.pth", "consolidated.00.pth"]: + possible_path = os.path.join(input_model_path, possible_name) + if os.path.exists(possible_path): + loaded = torch.load(possible_path, map_location="cpu") + break + assert loaded is not None + else: + # Sharded + loaded = [ + torch.load(os.path.join(input_model_path, f"consolidated.{i:02d}.pth"), map_location="cpu") + for i in range(num_shards) + ] + + # permute for sliced rotary + def permute(w, n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Load weights to the state dict + state_dict = {} + for layer_i in range(n_layers): + if num_shards == 1: + # Unsharded + state_dict.update( + { + f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads + ), + f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wk.weight"], + n_heads=num_key_value_heads, + dim1=key_value_dim, + ), + f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], + f"model.layers.{layer_i}.input_layernorm.weight": loaded[ + f"layers.{layer_i}.attention_norm.weight" + ], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"layers.{layer_i}.ffn_norm.weight" + ], + } + ) + # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677) + state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = ( + loaded[f"layers.{layer_i}.attention.q_normalization.weight"] + .view(dims_per_head // 2, 2) + .t() + .reshape(1, -1) + .repeat_interleave(n_heads, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = ( + loaded[f"layers.{layer_i}.attention.q_normalization.bias"] + .view(dims_per_head // 2, 2) + .t() + .reshape(1, -1) + .repeat_interleave(n_heads, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = ( + loaded[f"layers.{layer_i}.attention.k_normalization.weight"] + .view(dims_per_head // 2, 2) + .t() + .reshape(1, -1) + .repeat_interleave(num_key_value_heads, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = ( + loaded[f"layers.{layer_i}.attention.k_normalization.bias"] + .view(dims_per_head // 2, 2) + .t() + .reshape(1, -1) + .repeat_interleave(num_key_value_heads, 0) + ) + + else: + # Sharded + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": torch.stack( + [l[f"layers.{layer_i}.attention_norm.weight"] for l in loaded] + ).mean(dim=0), + f"model.layers.{layer_i}.post_attention_layernorm.weight": torch.stack( + [l[f"layers.{layer_i}.ffn_norm.weight"] for l in loaded] + ).mean(dim=0), + } + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) + for i in range(num_shards) + ], + dim=0, + ).reshape(dim, dim), + n_heads=n_heads, + ) + + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim), + n_heads=num_key_value_heads, + dim1=key_value_dim, + ) + + # qk_layernorm (see https://github.com/huggingface/transformers/pull/31534#issuecomment-2207354677) + state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = ( + torch.cat([l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0) for l in loaded]) + .view(num_shards, dims_per_head // 2, 2) + .transpose(1, 2) + .reshape(num_shards, -1) + .repeat_interleave(n_heads // num_shards, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = ( + torch.cat([l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0) for l in loaded]) + .view(num_shards, dims_per_head // 2, 2) + .transpose(1, 2) + .reshape(num_shards, -1) + .repeat_interleave(n_heads // num_shards, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = ( + torch.cat([l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0) for l in loaded]) + .view(num_shards, dims_per_head // 2, 2) + .transpose(1, 2) + .reshape(num_shards, -1) + .repeat_interleave(num_key_value_heads // num_shards, 0) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = ( + torch.cat([l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0) for l in loaded]) + .view(num_shards, dims_per_head // 2, 2) + .transpose(1, 2) + .reshape(num_shards, -1) + .repeat_interleave(num_key_value_heads // num_shards, 0) + ) + + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + if num_shards == 1: + # Unsharded + state_dict.update( + { + "model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.norm.weight": loaded["norm.weight"], + "lm_head.weight": loaded["output.weight"], + } + ) + else: + state_dict.update( + { + "model.embed_tokens.weight": torch.cat( + [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 + ), + "model.norm.weight": torch.stack([loaded[i]["norm.weight"] for i in range(num_shards)]).mean(dim=0), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), + } + ) + + # Load VQGAN weights + vqgan_path = os.path.join(input_base_path, "tokenizer/vqgan.ckpt") + vqgan_state_dict = torch.load(vqgan_path, map_location="cpu")["state_dict"] + for k, v in vqgan_state_dict.items(): + if "decoder" in k: + continue # we dont do image generation yet + state_dict[f"model.vqmodel.{k}"] = v + + # Write configs + ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 + multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + + with open(os.path.join(input_base_path, "tokenizer/text_tokenizer.json")) as tokenizer_file: + tokenizer_config = json.load(tokenizer_file) + vocabulary_map = tokenizer_config["model"]["vocab"] + vocabulary_map[""] = vocabulary_map[ + "" + ] # use a reserved token instead of adding a new one + del vocabulary_map[""] + + for token in tokenizer_config["added_tokens"]: + if token["content"] == "": + token["content"] = "" + + with open(os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), "w") as f: + json.dump(tokenizer_config, f) # save the new file to init tokenizer later + + vq_keys_to_replace = [ + ("ch", "base_channels"), + ("out_ch", "out_channels"), + ("n_embed", "num_embeddings"), + ("ch_mult", "channel_multiplier"), + ("double_z", "double_latent"), + ("z_channels", "latent_channels"), + ] + with open(os.path.join(input_base_path, "tokenizer/vqgan.yaml")) as vqgan_cfg_file: + vq_config = yaml.safe_load(vqgan_cfg_file)["model"]["params"] + vq_config.update(**vq_config["ddconfig"]) + for old, new in vq_keys_to_replace: + vq_config[new] = vq_config[old] + del vq_config["ddconfig"] + del vq_config["ckpt_path"] + del vq_config["lossconfig"] + + config = ChameleonConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=num_key_value_heads, + vocab_size=VOCAB_SIZE, + rope_theta=base, + max_position_embeddings=max_position_embeddings, + model_parallel_size=model_parallel_size, + swin_norm=swin_norm, + vq_config=vq_config, + vocabulary_map=vocabulary_map, + ) + with init_empty_weights(): + model = ChameleonForCausalLM(config) + + model.load_state_dict(state_dict, assign=True, strict=False) + model.save_pretrained(model_path, safe_serialization=True) + + # Load and save the processor + tokenizer = LlamaTokenizerFast( + tokenizer_file=os.path.join(input_base_path, "tokenizer/text_tokenizer_modified.json"), legacy=False + ) + tokenizer.sep_token_id = 8710 # assign to sep so that we can append it after input text + tokenizer.pad_token_id = 1 # assing to special pad_token + image_processor = ChameleonImageProcessor() + processor = ChameleonProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + del vqgan_state_dict + gc.collect() + + # Short inference on a few examples to check if generation makes sense + # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl + print("Loading the checkpoint in a Chameleon model...") + print("*" * 100) + model = ChameleonForCausalLM.from_pretrained( + model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto" + ) + processor = ChameleonProcessor.from_pretrained(model_path) + + prompt = "I'm very intrigued by this work of art:Please tell me about the artist." + image = Image.open( + requests.get( + "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True + ).raw + ) + inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.bfloat16) + length = inputs.input_ids.shape[1] + + out = model.generate(**inputs, max_new_tokens=40, do_sample=False) + generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] + + print(f"Generation for single-image: {generated_text}") + print("*" * 100) + + # Multi-image example + prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation." + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + image_2 = Image.open( + requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw + ) + + inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16) + length = inputs.input_ids.shape[1] + out = model.generate(**inputs, max_new_tokens=50, do_sample=False) + generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] + + print(f"Generation for multi-image: {generated_text}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of Chameleon weights", + ) + parser.add_argument( + "--model_size", + choices=["7B", "30B"], + help="" + " models correspond to the finetuned versions, and are specific to the Chameleon official release. For more details on Chameleon, checkout the original repo: https://huggingface.co/meta-chameleon", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model", + ) + parser.add_argument( + "--test_inference", + action="store_true", + help="Whether to load the model for generation to test it's converted correctly.", + ) + # Different Chameleon versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. + parser.add_argument( + "--chameleon_version", + choices=[1], + default=1, + type=int, + help="Version of the Chameleon model to convert", + ) + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + model_size=args.model_size, + chameleon_version=args.chameleon_version, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/chameleon/image_processing_chameleon.py b/src/transformers/models/chameleon/image_processing_chameleon.py new file mode 100644 index 00000000000000..021a1f5680c6bf --- /dev/null +++ b/src/transformers/models/chameleon/image_processing_chameleon.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Chameleon.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + +if is_vision_available(): + import PIL + + +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched video from {images}") + + +class ChameleonImageProcessor(BaseImageProcessor): + r""" + Constructs a Chameleon image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 512}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to 1): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to {"height": 512, "width": 512}): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to 0.0078): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[1.0, 1.0, 1.0]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PIL.Image.LANCZOS, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 0.0078, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 512} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [1.0, 1.0, 1.0] + self.image_std = image_std if image_std is not None else [1.0, 1.0, 1.0] + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + # Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: bool = None, + crop_size: int = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_batched_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [self.blend_rgba(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + def blend_rgba(self, image: ImageInput) -> ImageInput: + """ + Convert image to RGB by blending the transparency layer if it's in RGBA format. + If image is not `PIL.Image`, it si simply returned without modifications. + + Args: + image (`ImageInput`): + Image to convert. + """ + + if not isinstance(image, PIL.Image.Image): + return image + elif image.mode == "RGB": + return image + + img_rgba = np.array(image.convert("RGBA")) + + # If there is no transparency layer, simple convert and return. + if not (img_rgba[:, :, 3] < 255).any(): + return image.convert("RGB") + + # There is a transparency layer, blend it with a white background. + # Calculate the alpha proportion for blending. + alpha = img_rgba[:, :, 3] / 255.0 + img_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[:, :, np.newaxis] * img_rgba[:, :, :3] + return PIL.Image.fromarray(img_rgb.astype("uint8"), "RGB") diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py new file mode 100644 index 00000000000000..2cf7d5679138cf --- /dev/null +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -0,0 +1,1626 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Chameleon model.""" + +import math +from functools import cached_property +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig + + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ChameleonConfig" +_CHECKPOINT_FOR_DOC = "meta/chameleon-7b" +_EXPECTED_OUTPUT_SHAPE = [1, 7, 4096] +_SEQ_CLASS_EXPECTED_LOSS = 1.03 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon +class ChameleonRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + ChameleonRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(ChameleonRMSNorm) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon +class ChameleonRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Chameleon +class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding): + """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Chameleon +class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): + """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon +class ChameleonMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + # Ignore copy + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class ChameleonLayerNorm(nn.LayerNorm): + """ + LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta + from each shard separately to each head, instead of reducing. We can apply each head's own + gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed + in the last dimension. This module applies gamma/beta manually to fulfill this requirement. + """ + + def __init__(self, hidden_size, *args, **kwargs): + super().__init__(hidden_size, *args, **kwargs) + self.normalized_shape = (hidden_size[-1],) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5) + hidden_states = hidden_states * self.weight + self.bias + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class ChameleonAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.model_parallel_size = config.model_parallel_size + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) + self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim)) + self._init_rope() + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = ChameleonRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = ChameleonLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = ChameleonDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon +class ChameleonFlashAttention2(ChameleonAttention): + """ + Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (ChameleonRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class ChameleonSdpaAttention(ChameleonAttention): + """ + Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `ChameleonAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from ChameleonAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "ChameleonModel is using ChameleonSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +CHAMELEON_ATTENTION_CLASSES = { + "eager": ChameleonAttention, + "flash_attention_2": ChameleonFlashAttention2, + "sdpa": ChameleonSdpaAttention, +} + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON +class ChameleonDecoderLayer(nn.Module): + def __init__(self, config: ChameleonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = ChameleonMLP(config) + self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ChameleonSwinDecoderLayer(nn.Module): + def __init__(self, config: ChameleonConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CHAMELEON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = ChameleonMLP(config) + self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class ChameleonVQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.re_embed = self.num_embeddings + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + ) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach()) ** 2 + ) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + def __init__( + self, + config, + in_channels, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels) ** (-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +class ChameleonVQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if ( + config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla" + ): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values: torch.LongTensor): + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1], + ) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block](hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample(hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +CHAMELEON_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ChameleonVQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + CHAMELEON_VQ_START_DOCSTRING, +) +class ChameleonVQVAE(PreTrainedModel): + config_class = ChameleonVQVAEConfig + _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.GroupNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__(config) + + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode(self, pixel_values: torch.LongTensor): + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) + + return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +CHAMELEON_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ChameleonConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare chameleon Model outputting raw hidden-states without any specific head on top.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonPreTrainedModel(PreTrainedModel): + config_class = ChameleonConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, ChameleonVQVAE): + module.apply(module._init_weights) + elif isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +CHAMELEON_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`ChameleonImageProcessor.__call__`] for details. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Should always be a [`~cache_utils.Cache`] instance and the model will output the same cache instance. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare chameleon Model outputting raw hidden-states without any specific head on top.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonModel(ChameleonPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ChameleonDecoderLayer`] + + Args: + config: ChameleonConfig + """ + + def __init__(self, config: ChameleonConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map) + decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer + self.layers = nn.ModuleList( + [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.vqmodel = ChameleonVQVAE(config.vq_config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_image_tokens(self, pixel_values: torch.FloatTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. + """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.device, input_ids.dtype) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +@add_start_docstrings( + "Chameleon Model with a head on top used for outputting logits for next token prediction.", + CHAMELEON_START_DOCSTRING, +) +class ChameleonForCausalLM(ChameleonPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = ChameleonModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(CHAMELEON_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import ChameleonProcessor, ChameleonForCausalLM + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", torch_dtype=torch.bfloat16) + >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation." + >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw) + >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw) + + >>> inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + # Disallow image tokens which does not include special begin-image and end-image tokens + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, :, image_tokens] = torch.finfo(logits.dtype).min + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + pixel_values=None, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py new file mode 100644 index 00000000000000..559cac62e3d5a7 --- /dev/null +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Chameleon. +""" + +from typing import List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class ChameleonProcessor(ProcessorMixin): + r""" + Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single + processor. + + [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`]. + See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information. + + Args: + image_processor ([`ChameleonImageProcessor`]): + The image processor is a required input. + tokenizer ([`LlamaTokenizerFast`]): + The tokenizer is a required input. + image_seq_length (`int`, *optional*, defaults to 1024): + Sequence length of one image embedding. + image_token (`str`, *optional*, defaults to `""`): + The special token used to indicate image in the text. + """ + + attributes = ["image_processor", "tokenizer"] + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + image_processor_class = "ChameleonImageProcessor" + + def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): + self.image_seq_length = image_seq_length + self.image_token = image_token + self.image_start_token = "" # fixed tokens for start and end, so can hardcode + self.image_end_token = "" + super().__init__(image_processor, tokenizer) + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + images: ImageInput = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: int = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + return_for_text_completion: bool = False, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, or a list of strings") + + # Replace the image token with the expanded image token sequence + prompt_strings = [] + one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token + for sample in text: + sample = sample.replace(self.image_token, one_img_tokens) + if not return_for_text_completion: + sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode + prompt_strings.append(sample) + + data = self.tokenizer( + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + if images is not None: + pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] + data["pixel_values"] = pixel_values + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index eb9252fc9863f3..725d35b0096f7f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1835,6 +1835,41 @@ def load_tf_weights_in_canine(*args, **kwargs): requires_backends(load_tf_weights_in_canine, ["torch"]) +class ChameleonForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ChameleonModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ChameleonPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ChameleonProcessor(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ChameleonVQVAE(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ChineseCLIPModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 9d5175ed2aeab9..19f8dc1b1d9c9e 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -58,6 +58,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class ChameleonImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class ChineseCLIPFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/chameleon/__init__.py b/tests/models/chameleon/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/chameleon/test_image_processing_chameleon.py b/tests/models/chameleon/test_image_processing_chameleon.py new file mode 100644 index 00000000000000..cf39e1e17fce24 --- /dev/null +++ b/tests/models/chameleon/test_image_processing_chameleon.py @@ -0,0 +1,205 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import ChameleonImageProcessor + + +class ChameleonImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=200, + do_resize=True, + size=None, + do_center_crop=True, + crop_size=None, + do_normalize=True, + image_mean=[1.0, 1.0, 1.0], + image_std=[1.0, 1.0, 1.0], + do_convert_rgb=True, + ): + size = size if size is not None else {"shortest_edge": 18} + crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape + def expected_output_image_shape(self, images): + return self.num_channels, self.crop_size["height"], self.crop_size["width"] + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class ChameleonImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = ChameleonImageProcessor if is_vision_available() else None + + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Chameleon + def setUp(self): + super().setUp() + self.image_processor_tester = ChameleonImageProcessingTester(self) + + @property + # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + + def test_image_processor_from_dict_with_kwargs(self): + image_processor = self.image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"shortest_edge": 18}) + self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18}) + + image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84) + self.assertEqual(image_processor.size, {"shortest_edge": 42}) + self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84}) + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values + expected_output_image_shape = (1, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + def test_nested_input(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) + + # Test batched as a list of images + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) + + # Test batched as a nested list of images, where each sublist is one batch + image_inputs_nested = [image_inputs[:3], image_inputs[3:]] + encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values + expected_output_image_shape = (7, 3, 18, 18) + self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) + + # Image processor should return same pixel values, independently of input format + self.assertTrue((encoded_images_nested == encoded_images).all()) diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py new file mode 100644 index 00000000000000..7e3b688c93d209 --- /dev/null +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch chameleon model.""" + +import unittest + +import pytest +import requests +from parameterized import parameterized + +from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed +from transformers.testing_utils import ( + require_bitsandbytes, + require_flash_attn, + require_read_token, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_vision_available(): + from PIL import Image + +if is_torch_available(): + import torch + + from transformers import ( + ChameleonForCausalLM, + ChameleonModel, + ChameleonProcessor, + ) + + +class ChameleonModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=False, + use_input_mask=True, + use_labels=True, + vocab_size=99, + image_token_id=98, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + vq_num_embeds=12, + vq_embed_dim=12, + vq_channel_multiplier=[1, 2], + vq_img_token_start_id=10, # has to be less than vocab size when added with vq_num_embeds + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.image_token_id = image_token_id + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + self.vq_num_embeds = vq_num_embeds + self.vq_embed_dim = vq_embed_dim + self.vq_channel_multiplier = vq_channel_multiplier + self.vq_img_token_start_id = vq_img_token_start_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + # create dummy vocab map for image2bpe mapping if it needs remapping + # we assume that vocab size is big enough to accoun for image tokens somewhere in the beginning + # same way as in real ckpt, when img tokens are in first half of embeds + # we will need "vq_num_embeds" amount of tokens + + vocab_map = {i: chr(i) for i in range(self.vocab_size)} + vocab_map[self.image_token_id] = "" + start = self.vq_img_token_start_id + end = self.vq_img_token_start_id + self.vq_num_embeds + for i in range(start, end): + vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG + + return ChameleonConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + vocabulary_map={v: k for k, v in vocab_map.items()}, + vq_config=self.get_vq_config(), + ) + + def get_vq_config(self): + return { + "embed_dim": self.vq_embed_dim, + "num_embeddings": self.vq_num_embeds, + "latent_channels": self.vq_embed_dim, + "in_channels": 3, + "base_channels": 32, # we have a GroupNorm of 32 groups, so can't do less + "channel_multiplier": self.vq_channel_multiplier, + } + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = ChameleonModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = ChameleonForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + model = ChameleonForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (ChameleonModel, ChameleonForCausalLM) if is_torch_available() else () + all_generative_model_classes = (ChameleonForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": ChameleonModel, + "text-generation": ChameleonForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + + def setUp(self): + self.model_tester = ChameleonModelTester(self) + self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = ChameleonModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = ChameleonModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + @require_flash_attn + @require_read_token + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = ChameleonForCausalLM.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + device_map={"": 0}, + ) + + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + texts = ["hi", "Hello this is a very long sentence"] + + processor.tokenizer.padding_side = "right" + + inputs = processor(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = processor.tokenizer.batch_decode(output_native) + + model = ChameleonForCausalLM.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + attn_implementation="flash_attention_2", + ) + + output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_2 = processor.tokenizer.batch_decode(output_fa_2) + + self.assertListEqual(output_native, output_fa_2) + + @unittest.skip("Chameleon forces some token ids to be -inf!") + def test_batching_equivalence(self): + pass + + +@require_torch +class ChameleonIntegrationTest(unittest.TestCase): + @slow + @require_bitsandbytes + @require_read_token + def test_model_7b(self): + model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto") + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + prompt = "Describe what do you see here and tell me about the history behind it?" + + inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + @require_read_token + def test_model_7b_batched(self): + model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto") + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + image_2 = Image.open( + requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw + ) + prompts = [ + "Describe what do you see here and tell me about the history behind it?", + "What constellation is this image showing?", + ] + + inputs = processor(prompts, images=[image, image_2], padding=True, return_tensors="pt").to( + model.device, torch.float16 + ) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = [ + 'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in', + 'What constellation is this image showing?The image is showing the constellation of Orion.' + ] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @slow + @require_bitsandbytes + @require_read_token + def test_model_7b_multi_image(self): + model = ChameleonForCausalLM.from_pretrained("facebook/chameleon-7b", load_in_4bit=True, device_map="auto") + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + + image = Image.open( + requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw + ) + image_2 = Image.open( + requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw + ) + prompt = "What do these two images have in common?" + + inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.float16) + + # greedy generation outputs + EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between two things that are not necessarily related. The first image shows a group of stars, while the second image shows a network of lines connecting two points. The connection between'] # fmt: skip + generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False) + text = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a73417e4164821..dd041188cdca3d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -259,9 +259,11 @@ def check_save_load(out1, out2): # make sure we don't have nans out_2 = out2.cpu().numpy() out_2[np.isnan(out_2)] = 0 + out_2 = out_2[~np.isneginf(out_2)] out_1 = out1.cpu().numpy() out_1[np.isnan(out_1)] = 0 + out_1 = out_1[~np.isneginf(out_1)] max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) @@ -660,6 +662,8 @@ def check_determinism(first, second): out_2 = second.cpu().numpy() out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] + out_1 = out_1[~np.isneginf(out_1)] + out_2 = out_2[~np.isneginf(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) diff --git a/utils/check_repo.py b/utils/check_repo.py index 5e6b42578313d5..293089ccb662b4 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -127,6 +127,7 @@ "SeamlessM4TTextToUnitModel", # Building part of bigger (tested) model. "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. + "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model ] # Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't @@ -319,6 +320,7 @@ "SegGptForImageSegmentation", "SiglipVisionModel", "SiglipTextModel", + "ChameleonVQVAE", # no autoclass for VQ-VAE models ] # DO NOT edit this list!