Skip to content

Commit

Permalink
Chameleon: add model (#31534)
Browse files Browse the repository at this point in the history
* Chameleon model integration

Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com>
Co-authored-by: Leonid Shamis <leonid.shamis@gmail.com>

* fix 7B, again. mask away image tokens

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove pretrained_config_map

* make fixup passing up to utils/check_config_docstrings.py; vqgan moved to the modeling file

* remove tokenizer (use llama's); remove codechameleon tests

* a few copied from statements and minor changes

* copied from in ChameleonModel

* some copies in ChameleonForCausalLM

* a few more copies

* VQModel moved to ChameleonModel (as opposed to being in the processor)

* ChameleonProcessor ready

* Fix chameleon weights convert

* update conversion script

* clean-up processing

* update modeling a bit

* update

* update (throws error...)

* correct conversion ready

* fix tests

* fix docs

* docs

* ve swin norm

* fix device for vocab map

* add normalization

* update

* update script with rope rotations

* final fix on model conversion

* add slow tests

* more info in docs

* fix repo consistency tests

* fix repo tests

* fix-copies

* hope this will make CI happy

* fix for 30b model

* Update docs/source/en/index.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/modeling_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/chameleon.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/auto/configuration_auto.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/image_processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/image_processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/image_processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/image_processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/modeling_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/chameleon/processing_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/chameleon/test_modeling_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/chameleon/test_modeling_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/chameleon/test_modeling_chameleon.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* address comments

* remove assertion in conversion script

* add image processor test

* not copied

* port changes for qk layernorm

* fix-copies

* read token decorator for tests

* [run-slow] chameleon

* one more read-token

* address some comments

* qk norm changes

* tests and repo check

* moved rope permutations to conversion, YAY!

* fix past kv check

* docs

* layernorm done!

* let's be consistent in naming

* fix slow tests

* weird thing with slow CI, but let's see

* once more try

* remove past-kv as tuple following llama

* ignore

* style

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
Co-authored-by: jacobkahn <jacobkahn1@gmail.com>
Co-authored-by: Leonid Shamis <leonid.shamis@gmail.com>
Co-authored-by: Leonid Shamis <lshamis@meta.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
10 people authored and itazap committed Jul 25, 2024
1 parent ac4e328 commit 5b2a504
Show file tree
Hide file tree
Showing 24 changed files with 3,950 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@
title: CamemBERT
- local: model_doc/canine
title: CANINE
- local: model_doc/chameleon
title: chameleon
- local: model_doc/codegen
title: CodeGen
- local: model_doc/code_llama
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Flax), PyTorch, and/or TensorFlow.
| [ByT5](model_doc/byt5) ||||
| [CamemBERT](model_doc/camembert) ||||
| [CANINE](model_doc/canine) ||||
| [Chameleon](model_doc/chameleon) ||||
| [Chinese-CLIP](model_doc/chinese_clip) ||||
| [CLAP](model_doc/clap) ||||
| [CLIP](model_doc/clip) ||||
Expand Down
189 changes: 189 additions & 0 deletions docs/source/en/model_doc/chameleon.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Chameleon

## Overview

The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models
](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet.


The abstract from the paper is the following:

*We present Chameleon, a family of early-fusion token-based mixed-modal models capable of understanding and generating images and text in any arbitrary sequence. We outline a stable training
approach from inception, an alignment recipe, and an architectural parameterization tailored for the
early-fusion, token-based, mixed-modal setting. The models are evaluated on a comprehensive range
of tasks, including visual question answering, image captioning, text generation, image generation, and
long-form mixed modal generation. Chameleon demonstrates broad and general capabilities, including
state-of-the-art performance in image captioning tasks, outperforms Llama-2 in text-only tasks while
being competitive with models such as Mixtral 8x7B and Gemini-Pro, and performs non-trivial image
generation, all in a single model. It also matches or exceeds the performance of much larger models,
including Gemini Pro and GPT-4V, according to human judgments on a new long-form mixed-modal
generation evaluation, where either the prompt or outputs contain mixed sequences of both images and
text. Chameleon marks a significant step forward in a unified modeling of full multimodal documents*


<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/chameleon_arch.png"
alt="drawing" width="600"/>

<small> Chameleon incorporates a vector quantizer module to transform images into discrete tokens. That also enables image geenration using an auto-regressive transformer. Taken from the <a href="https://arxiv.org/abs/2405.09818v1">original paper.</a> </small>

This model was contributed by [joaogante](https://huggingface.co/joaogante) and [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/facebookresearch/chameleon).


## Usage tips

- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to set `processor.tokenizer.padding_side = "left"` before generating.

- Note that Chameleon was tuned for safety alignment. If the model is refusing to answer, consider asking a more concrete question, instead of an open question.

- Chameleon generates in chat format which means that the generated text will always be the "assistant's turn". You can enable a text completion generation by passing `return_for_text_completion=True` when calling the processor.

> [!NOTE]
> Chameleon implementation in Transformers uses a special image token to indicate where to merge image embeddings. For special image token we didn't add a new one but used one of the reserved tokens: `<reserved08707>`.
## Usage example

### Single image inference

Here's how to load the model and perform inference in half-precision (`torch.float16`):

```python
from transformers import ChameleonProcessor, ChameleonForCausalLM
import torch
from PIL import Image
import requests

processor = ChameleonProcessor.from_pretrained("meta-chameleon")
model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")

# prepare image and text prompt
url = "https://bjiujitsu.com/wp-content/uploads/2021/01/jiu_jitsu_belt_white_1.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What color is the belt in this image?<image>"

inputs = processor(prompt, image, return_tensors="pt").to(model.device)

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```

### Multi image inference

Chameleon can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). Here is how you can do it:

```python
from transformers import ChameleonProcessor, ChameleonForCausalLM
import torch
from PIL import Image
import requests

processor = ChameleonProcessor.from_pretrained("meta-chameleon")
model = ChameleonForCausalLM.from_pretrained("meta-chameleon", torch_dtype=torch.float16, device_map="auto")

# Get three different images
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image_stop = Image.open(requests.get(url, stream=True).raw)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_cats = Image.open(requests.get(url, stream=True).raw)

url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
image_snowman = Image.open(requests.get(url, stream=True).raw)

# Prepare a batched prompt, where the first one is a multi-image prompt and the second is not
prompts = [
"What do these images have in common?<image><image>",
"<image>What is shown in this image?"
]

# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
```

## Model optimization

### Quantization using Bitsandbytes

The model can be loaded in 8 or 4 bits, greatly reducing the memory requirements while maintaining the performance of the original model. First make sure to install bitsandbytes, `pip install bitsandbytes` and make sure to have access to a CUDA compatible GPU device. Simply change the snippet above with:

```python
from transformers import ChameleonForCausalLM, BitsAndBytesConfig

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)

model = ChameleonForCausalLM.from_pretrained("meta-chameleon", quantization_config=quantization_config, device_map="auto")
```

### Use Flash-Attention 2 and SDPA to further speed-up generation

The models supports both, Flash-Attention 2 and PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) which can be enables for optimization. SDPA is the default options when you load the model, If you want to switch for Flash Attention 2, first make sure to install flash-attn. Refer to the [original repository](https://github.com/Dao-AILab/flash-attention) regarding that package installation. Simply change the snippet above with:

```python
from transformers import ChameleonForCausalLM

model = ChameleonForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
attn_implementation="flash_attention_2"
).to(0)
```

## ChameleonConfig

[[autodoc]] ChameleonConfig

## ChameleonVQVAEConfig

[[autodoc]] ChameleonVQVAEConfig

## ChameleonProcessor

[[autodoc]] ChameleonProcessor

## ChameleonImageProcessor

[[autodoc]] ChameleonImageProcessor
- preprocess

## ChameleonVQVAE

[[autodoc]] ChameleonVQVAE
- forward

## ChameleonModel

[[autodoc]] ChameleonModel
- forward

## ChameleonForCausalLM

[[autodoc]] ChameleonForCausalLM
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ FlashAttention-2 is experimental and may change considerably in future versions.
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
Expand Down Expand Up @@ -198,6 +199,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@
"CanineConfig",
"CanineTokenizer",
],
"models.chameleon": [
"ChameleonConfig",
"ChameleonProcessor",
"ChameleonVQVAEConfig",
],
"models.chinese_clip": [
"ChineseCLIPConfig",
"ChineseCLIPProcessor",
Expand Down Expand Up @@ -1125,6 +1130,7 @@
_import_structure["models.bit"].extend(["BitImageProcessor"])
_import_structure["models.blip"].extend(["BlipImageProcessor"])
_import_structure["models.bridgetower"].append("BridgeTowerImageProcessor")
_import_structure["models.chameleon"].append("ChameleonImageProcessor")
_import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"])
_import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"])
_import_structure["models.conditional_detr"].extend(
Expand Down Expand Up @@ -1608,6 +1614,15 @@
"load_tf_weights_in_canine",
]
)
_import_structure["models.chameleon"].extend(
[
"ChameleonForCausalLM",
"ChameleonModel",
"ChameleonPreTrainedModel",
"ChameleonProcessor",
"ChameleonVQVAE",
]
)
_import_structure["models.chinese_clip"].extend(
[
"ChineseCLIPModel",
Expand Down Expand Up @@ -4890,6 +4905,11 @@
CanineConfig,
CanineTokenizer,
)
from .models.chameleon import (
ChameleonConfig,
ChameleonProcessor,
ChameleonVQVAEConfig,
)
from .models.chinese_clip import (
ChineseCLIPConfig,
ChineseCLIPProcessor,
Expand Down Expand Up @@ -5807,6 +5827,7 @@
from .models.bit import BitImageProcessor
from .models.blip import BlipImageProcessor
from .models.bridgetower import BridgeTowerImageProcessor
from .models.chameleon import ChameleonImageProcessor
from .models.chinese_clip import (
ChineseCLIPFeatureExtractor,
ChineseCLIPImageProcessor,
Expand Down Expand Up @@ -6254,6 +6275,13 @@
CaninePreTrainedModel,
load_tf_weights_in_canine,
)
from .models.chameleon import (
ChameleonForCausalLM,
ChameleonModel,
ChameleonPreTrainedModel,
ChameleonProcessor,
ChameleonVQVAE,
)
from .models.chinese_clip import (
ChineseCLIPModel,
ChineseCLIPPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
byt5,
camembert,
canine,
chameleon,
chinese_clip,
clap,
clip,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
("bros", "BrosConfig"),
("camembert", "CamembertConfig"),
("canine", "CanineConfig"),
("chameleon", "ChameleonConfig"),
("chinese_clip", "ChineseCLIPConfig"),
("chinese_clip_vision_model", "ChineseCLIPVisionConfig"),
("clap", "ClapConfig"),
Expand Down Expand Up @@ -329,6 +330,7 @@
("byt5", "ByT5"),
("camembert", "CamemBERT"),
("canine", "CANINE"),
("chameleon", "Chameleon"),
("chinese_clip", "Chinese-CLIP"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "CLAP"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
("blip", ("BlipImageProcessor",)),
("blip-2", ("BlipImageProcessor",)),
("bridgetower", ("BridgeTowerImageProcessor",)),
("chameleon", ("ChameleonImageProcessor",)),
("chinese_clip", ("ChineseCLIPImageProcessor",)),
("clip", ("CLIPImageProcessor",)),
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
("bros", "BrosModel"),
("camembert", "CamembertModel"),
("canine", "CanineModel"),
("chameleon", "ChameleonModel"),
("chinese_clip", "ChineseCLIPModel"),
("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
("clap", "ClapModel"),
Expand Down Expand Up @@ -445,6 +446,7 @@
("blenderbot-small", "BlenderbotSmallForCausalLM"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForCausalLM"),
("chameleon", "ChameleonForCausalLM"),
("code_llama", "LlamaForCausalLM"),
("codegen", "CodeGenForCausalLM"),
("cohere", "CohereForCausalLM"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
("blip", "BlipProcessor"),
("blip-2", "Blip2Processor"),
("bridgetower", "BridgeTowerProcessor"),
("chameleon", "ChameleonProcessor"),
("chinese_clip", "ChineseCLIPProcessor"),
("clap", "ClapProcessor"),
("clip", "CLIPProcessor"),
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@
),
),
("canine", ("CanineTokenizer", None)),
(
"chameleon",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(
"clap",
Expand Down
Loading

0 comments on commit 5b2a504

Please sign in to comment.