Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama3-llava-next-8b to llava_next conversion script #31395

Merged
11 changes: 11 additions & 0 deletions docs/source/en/model_doc/llava_next.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ print(text_prompt)
"<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
```

[llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llava-next-8b-hf) requires the following format:

```bash
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
```

[llava-next-72b-hf](https://huggingface.co/llava-hf/llava-next-72b-hf) and [llava-next-110b-hf](https://huggingface.co/llava-hf/llava-next-110b-hf) require the following format:

```bash
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
```

## Usage example

Expand Down
125 changes: 90 additions & 35 deletions src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""

import argparse
import gc
import glob
import json
from pathlib import Path
Expand Down Expand Up @@ -111,6 +112,16 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
elif model_id == "liuhaotian/llava-v1.6-34b":
text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
image_token_index = 64000
elif model_id == "lmms-lab/llama3-llava-next-8b":
text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
image_token_index = 128256
elif model_id == "lmms-lab/llava-next-72b":
text_model_id = "Qwen/Qwen1.5-72B-Chat"
image_token_index = 151646
elif model_id == "lmms-lab/llava-next-110b":
text_model_id = "Qwen/Qwen1.5-110B-Chat"
image_token_index = 151646

vision_model_id = data["mm_vision_tower"]

torch.set_default_dtype(torch.float16)
Expand All @@ -120,7 +131,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
if model_id in ("liuhaotian/llava-v1.6-mistral-7b", "lmms-lab/llama3-llava-next-8b"):
# Mistral-7B doesn't have a padding token set yet
tokenizer.add_special_tokens({"pad_token": "<pad>"})

Expand Down Expand Up @@ -151,28 +162,45 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):

# We add an image token so we resize the model
# Pad to 64 for performance reasons
pad_shape = 64
vocab_size = config.text_config.vocab_size
if model_id == "liuhaotian/llava-v1.6-34b":
# this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
num_tokens = vocab_size + 3
else:
# this one has 2 additional tokens, namely <image> and <pad>
num_tokens = vocab_size + 2
model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
tuple(
(dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
),
dim=0,
)
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
dim=0,
)
# Qwen-based models have extra unused space in the vocab size already, so no need to resize
if model_id not in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
pad_shape = 64
vocab_size = config.text_config.vocab_size
if model_id == "liuhaotian/llava-v1.6-34b":
# this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
num_tokens = vocab_size + 3
else:
# this one has 2 additional tokens, namely <image> and <pad>
num_tokens = vocab_size + 2
model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
tuple(
(
dist.sample()
for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])
)
),
dim=0,
)
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
dim=0,
)

print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)

# Make space so we can load the model properly now.
del state_dict
gc.collect()

device = "cuda:2"
model.to(device)
# Load everything back for inference tests in float32 because prev script was written as that
# Though it's mostly loaded in fp16 as original weights are in fp16
model = LlavaNextForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, device_map="auto")
processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path)
device = model.device

# prepare inputs
image = load_image()
Expand All @@ -182,6 +210,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
elif model_id == "liuhaotian/llava-v1.6-34b":
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
elif model_id == "lmms-lab/llama3-llava-next-8b":
prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"

inputs = processor(images=image, text=prompt, return_tensors="pt")

# verify inputs
Expand All @@ -194,8 +227,6 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
original_input_ids = torch.load(filepath, map_location="cpu")
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
original_input_ids[original_input_ids == -200] = image_token_index
print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200]))

assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()

elif model_id == "liuhaotian/llava-v1.6-34b":
Expand Down Expand Up @@ -243,6 +274,26 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_slice = torch.tensor(
[[-3.9648, 1.1396, 3.3145], [-5.3594, -1.5654, -1.9619], [-12.3750, -10.6797, -9.3125]],
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llava-next-72b":
# Not yet checked against reference
expected_slice = torch.tensor(
[[3.7148, 3.9277, 3.4395], [-0.4341, 1.1387, 6.5117], [3.2324, 3.4688, 4.1133]],
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llava-next-110b":
# Not yet checked against reference
expected_slice = torch.tensor(
[[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]],
dtype=torch.float32,
device=device,
)
else:
raise ValueError(f"Model {model_id} not supported")

Expand All @@ -268,6 +319,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM"
elif model_id == "liuhaotian/llava-v1.6-34b":
expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-"
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL'
elif model_id == "lmms-lab/llava-next-72b":
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes"
elif model_id == "lmms-lab/llava-next-110b":
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a"
else:
raise ValueError(f"Model {model_id} not supported")

Expand All @@ -281,7 +338,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):

inputs = processor(
images=[image, cats_image],
text=[prompt, "[INST] <image>\nHow many cats are there? [/INST]"],
text=[prompt, prompt],
padding=True,
return_tensors="pt",
).to(device)
Expand All @@ -305,16 +362,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(outputs)

if pytorch_dump_folder_path is not None:
print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)

if push_to_hub:
repo_id = model_id.split("/")[-1]
model.push_to_hub(f"llava-hf/{repo_id}-hf")
processor.push_to_hub(f"llava-hf/{repo_id}-hf")
checkpoint_name = model_id.split("/")[-1]
print(f"Pushing to repo llava-hf/{checkpoint_name}-hf")
model.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf")


if __name__ == "__main__":
Expand All @@ -328,11 +380,14 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
"liuhaotian/llava-v1.6-vicuna-7b",
"liuhaotian/llava-v1.6-vicuna-13b",
"liuhaotian/llava-v1.6-34b",
"lmms-lab/llama3-llava-next-8b",
"lmms-lab/llava-next-72b",
"lmms-lab/llava-next-110b",
],
required=False,
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
"--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why make this required?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to save the model and then load again, so that it's loaded by automatically inferring how to split weights on each device/cpu/disk depending on how much memory available.

Especially required in big models that we're adding, they need multi-gpu inference but we can't init model from config in multi-gpu setting

)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
Expand Down
Loading