From d8597d58291b7332409d8fffdeba90eb971aeace Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Wed, 12 Jun 2024 21:37:12 +0000 Subject: [PATCH 01/10] Add llama3-llava-next-8b to llava_next conversion script Adds support for the lmms-lab/llama3-llava-next-8b model to the convert_llava_next_weights_to_hf.py script, along with an example prompt generated from the llava_llama_3 conv_template in the LLaVA-NeXT repo. --- .../convert_llava_next_weights_to_hf.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 2c8aefe39dc255..432f64191440c9 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -111,6 +111,10 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "liuhaotian/llava-v1.6-34b": text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B" image_token_index = 64000 + elif model_id == "lmms-lab/llama3-llava-next-8b": + text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + image_token_index = 128256 + vision_model_id = data["mm_vision_tower"] torch.set_default_dtype(torch.float16) @@ -120,7 +124,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) - if model_id == "liuhaotian/llava-v1.6-mistral-7b": + if model_id == "liuhaotian/llava-v1.6-mistral-7b" or model_id == "lmms-lab/llama3-llava-next-8b": # Mistral-7B doesn't have a padding token set yet tokenizer.add_special_tokens({"pad_token": ""}) @@ -182,6 +186,9 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT:" elif model_id == "liuhaotian/llava-v1.6-34b": prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" + elif model_id == "lmms-lab/llama3-llava-next-8b": + prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + inputs = processor(images=image, text=prompt, return_tensors="pt") # verify inputs @@ -243,6 +250,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): dtype=torch.float32, device=device, ) + elif model_id == "lmms-lab/llama3-llava-next-8b": + expected_slice = torch.tensor( + [[-3.9648, 1.1396, 3.3145], [-3.9648, 1.1396, 3.3145], [-3.6426, -0.0081, -0.1266]], + dtype=torch.float32, + device=device + ) else: raise ValueError(f"Model {model_id} not supported") @@ -268,6 +281,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM" elif model_id == "liuhaotian/llava-v1.6-34b": expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-" + elif model_id == "lmms-lab/llama3-llava-next-8b": + expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image appears to be a radar chart, also known as a spider chart or a web chart, which is a type of graph that displays multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "' else: raise ValueError(f"Model {model_id} not supported") @@ -328,6 +343,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): "liuhaotian/llava-v1.6-vicuna-7b", "liuhaotian/llava-v1.6-vicuna-13b", "liuhaotian/llava-v1.6-34b", + "lmms-lab/llama3-llava-next-8b", ], required=False, ) From 57413fd42373d314145b94bbe3ed848b83d34135 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sat, 15 Jun 2024 19:39:03 +0000 Subject: [PATCH 02/10] Exclude <|begin_of_text|> from prompt example This token gets added automatically, so it should not be included in the prompt example. --- .../models/llava_next/convert_llava_next_weights_to_hf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 432f64191440c9..9139a7c6ea7776 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -187,7 +187,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "liuhaotian/llava-v1.6-34b": prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" elif model_id == "lmms-lab/llama3-llava-next-8b": - prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" inputs = processor(images=image, text=prompt, return_tensors="pt") @@ -252,9 +252,9 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): ) elif model_id == "lmms-lab/llama3-llava-next-8b": expected_slice = torch.tensor( - [[-3.9648, 1.1396, 3.3145], [-3.9648, 1.1396, 3.3145], [-3.6426, -0.0081, -0.1266]], + [[-3.9648, 1.1396, 3.3145], [-5.3594, -1.5654, -1.9619], [-12.3750, -10.6797, -9.3125]], dtype=torch.float32, - device=device + device=device, ) else: raise ValueError(f"Model {model_id} not supported") @@ -282,7 +282,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "liuhaotian/llava-v1.6-34b": expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-" elif model_id == "lmms-lab/llama3-llava-next-8b": - expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image appears to be a radar chart, also known as a spider chart or a web chart, which is a type of graph that displays multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "' + expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL' else: raise ValueError(f"Model {model_id} not supported") From 73678672eac8333f8175b2a704d67e8e60ae89f5 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 16 Jun 2024 17:00:34 +0000 Subject: [PATCH 03/10] Add llava-next-72b and llava-next-110b Adds the Qwen-based LLaVA-Next models to the conversion script, along with changes to load the models on multiple GPUs for inference. --- .../convert_llava_next_weights_to_hf.py | 78 ++++++++++++++----- 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 9139a7c6ea7776..46dc326b449ef4 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -114,6 +114,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "lmms-lab/llama3-llava-next-8b": text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" image_token_index = 128256 + elif model_id == "lmms-lab/llava-next-72b": + text_model_id = "Qwen/Qwen1.5-72B-Chat" + image_token_index = 151646 + elif model_id == "lmms-lab/llava-next-110b": + text_model_id = "Qwen/Qwen1.5-110B-Chat" + image_token_index = 151646 vision_model_id = data["mm_vision_tower"] @@ -155,28 +161,37 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): # We add an image token so we resize the model # Pad to 64 for performance reasons - pad_shape = 64 - vocab_size = config.text_config.vocab_size - if model_id == "liuhaotian/llava-v1.6-34b": - # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and - num_tokens = vocab_size + 3 - else: - # this one has 2 additional tokens, namely and - num_tokens = vocab_size + 2 - model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape) - model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( - tuple( - (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])) - ), - dim=0, - ) - model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( - tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))), - dim=0, - ) + # Qwen-based models have extra unused space in the vocab size already, so no need to resize + if model_id not in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: + pad_shape = 64 + vocab_size = config.text_config.vocab_size + if model_id == "liuhaotian/llava-v1.6-34b": + # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and + num_tokens = vocab_size + 3 + else: + # this one has 2 additional tokens, namely and + num_tokens = vocab_size + 2 + model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape) + model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))), + dim=0, + ) - device = "cuda:2" - model.to(device) + if model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: + # For these big models need to do multi-gpu inference, so reload the model with device_map="auto" in order to use accelerate + # Is there a way to do this without saving and reloading the model? + model.save_pretrained("/tmp/llava_qwen") + model = LlavaNextForConditionalGeneration.from_pretrained("/tmp/llava_qwen", device_map="auto") + device = "cuda" + else: + device = "cuda:2" + model.to(device) # prepare inputs image = load_image() @@ -188,6 +203,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" elif model_id == "lmms-lab/llama3-llava-next-8b": prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + elif model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: + prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n" inputs = processor(images=image, text=prompt, return_tensors="pt") @@ -256,6 +273,19 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): dtype=torch.float32, device=device, ) + elif model_id == "lmms-lab/llava-next-72b": + # Not yet checked against reference + expected_slice = torch.tensor( + [[3.7148, 3.9277, 3.4395], [-0.4341, 1.1387, 6.5117], [3.2324, 3.4688, 4.1133]], + dtype=torch.float32, + device=device, + ) + elif model_id == "lmms-lab/llava-next-110b": + # Not yet checked against reference + expected_slice = torch.tensor([[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]], + dtype=torch.float32, + device=device, + ) else: raise ValueError(f"Model {model_id} not supported") @@ -283,6 +313,10 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-" elif model_id == "lmms-lab/llama3-llava-next-8b": expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL' + elif model_id == "lmms-lab/llava-next-72b": + expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes" + elif model_id == "lmms-lab/llava-next-110b": + expected_text = 'system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a' else: raise ValueError(f"Model {model_id} not supported") @@ -344,6 +378,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): "liuhaotian/llava-v1.6-vicuna-13b", "liuhaotian/llava-v1.6-34b", "lmms-lab/llama3-llava-next-8b", + "lmms-lab/llava-next-72b", + "lmms-lab/llava-next-110b" ], required=False, ) From 5146563675a5fe121cfa00c14eca9542e1ce1bb8 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 30 Jun 2024 20:58:18 +0000 Subject: [PATCH 04/10] Add llama3 and qwen prompt formats to docs --- docs/source/en/model_doc/llava_next.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index a4a1419ee00ac8..1e6bf66818692a 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -66,6 +66,19 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/ "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" ``` +llama3-llava-next-8b-hf requires the following format: + +```bash +"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +``` + +llava-next-72b-hf and llava-next-110b-hf require the following format: + +```bash +"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n" +``` + + ## Usage example ### Single image inference From b115763f0ed677bff817f00484b7631648b6f0dc Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Sun, 30 Jun 2024 22:00:19 +0000 Subject: [PATCH 05/10] Chat prompt and padding side left for llama3 batched --- .../convert_llava_next_weights_to_hf.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 46dc326b449ef4..47e6ac95b03032 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -174,7 +174,10 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape) model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( tuple( - (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])) + ( + dist.sample() + for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]) + ) ), dim=0, ) @@ -282,7 +285,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): ) elif model_id == "lmms-lab/llava-next-110b": # Not yet checked against reference - expected_slice = torch.tensor([[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]], + expected_slice = torch.tensor( + [[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]], dtype=torch.float32, device=device, ) @@ -316,7 +320,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "lmms-lab/llava-next-72b": expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes" elif model_id == "lmms-lab/llava-next-110b": - expected_text = 'system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a' + expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a" else: raise ValueError(f"Model {model_id} not supported") @@ -328,9 +332,15 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): url = "http://images.cocodataset.org/val2017/000000039769.jpg" cats_image = Image.open(requests.get(url, stream=True).raw) + if model_id == "lmms-lab/llama3-llava-next-8b": + cats_prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nHow many cats are there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + processor.tokenizer.padding_side = "left" + else: + cats_prompt = "[INST] \nHow many cats are there? [/INST]" + inputs = processor( images=[image, cats_image], - text=[prompt, "[INST] \nHow many cats are there? [/INST]"], + text=[prompt, cats_prompt], padding=True, return_tensors="pt", ).to(device) @@ -379,7 +389,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): "liuhaotian/llava-v1.6-34b", "lmms-lab/llama3-llava-next-8b", "lmms-lab/llava-next-72b", - "lmms-lab/llava-next-110b" + "lmms-lab/llava-next-110b", ], required=False, ) From fb1e50138bf1b05c54aeb1b5d86b17f569722531 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 19 Jul 2024 12:34:36 +0200 Subject: [PATCH 06/10] update --- .../convert_llava_next_weights_to_hf.py | 67 +++++++++++-------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 47e6ac95b03032..a0bbf3ce94b58f 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -24,6 +24,7 @@ """ import argparse +import gc import glob import json from pathlib import Path @@ -186,15 +187,22 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): dim=0, ) - if model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: - # For these big models need to do multi-gpu inference, so reload the model with device_map="auto" in order to use accelerate - # Is there a way to do this without saving and reloading the model? - model.save_pretrained("/tmp/llava_qwen") - model = LlavaNextForConditionalGeneration.from_pretrained("/tmp/llava_qwen", device_map="auto") - device = "cuda" - else: - device = "cuda:2" - model.to(device) + print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + # Make space so we can load the model properly now. + del state_dict + gc.collect() + + # Load everything back for inference tests in float23 because prev script was written as that + # Though it's mostly loaded in fp16 as original weights are in fp16 + model = LlavaNextForConditionalGeneration.from_pretrained( + pytorch_dump_folder_path, torch_dtype=torch.float32, device_map="auto" + ) + processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path) + device = model.device # prepare inputs image = load_image() @@ -221,8 +229,6 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): original_input_ids = torch.load(filepath, map_location="cpu") # replace -200 by image_token_index (since we use token ID = 32000 for the image token) original_input_ids[original_input_ids == -200] = image_token_index - print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200])) - assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() elif model_id == "liuhaotian/llava-v1.6-34b": @@ -332,15 +338,9 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): url = "http://images.cocodataset.org/val2017/000000039769.jpg" cats_image = Image.open(requests.get(url, stream=True).raw) - if model_id == "lmms-lab/llama3-llava-next-8b": - cats_prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nHow many cats are there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - processor.tokenizer.padding_side = "left" - else: - cats_prompt = "[INST] \nHow many cats are there? [/INST]" - inputs = processor( images=[image, cats_image], - text=[prompt, cats_prompt], + text=[prompt, prompt], padding=True, return_tensors="pt", ).to(device) @@ -364,16 +364,27 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) print(outputs) - if pytorch_dump_folder_path is not None: - print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}") - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - model.save_pretrained(pytorch_dump_folder_path) - processor.save_pretrained(pytorch_dump_folder_path) + import os + + from huggingface_hub import HfApi + + repo_id = model_id.split("/")[-1] + print(f"Pushing to repo llava-hf/{repo_id}-hf") + + api = HfApi() + for file in os.listdir(pytorch_dump_folder_path): + api.upload_file( + path_or_fileobj=f"{pytorch_dump_folder_path}/{file}", + path_in_repo=file, + repo_id=f"llava-hf/{repo_id}-hf", + repo_type="model", + ) - if push_to_hub: - repo_id = model_id.split("/")[-1] - model.push_to_hub(f"llava-hf/{repo_id}-hf") - processor.push_to_hub(f"llava-hf/{repo_id}-hf") + # if push_to_hub: + # repo_id = model_id.split("/")[-1] + # print(f"Pushing to repo llava-hf/{repo_id}-hf") + # model.push_to_hub(f"llava-hf/{repo_id}-hf") + # processor.push_to_hub(f"llava-hf/{repo_id}-hf") if __name__ == "__main__": @@ -394,7 +405,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): required=False, ) parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory." ) parser.add_argument( "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." From 79e5a555310631e4b33d23a0bfd0d70b9b08c780 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 22 Jul 2024 09:59:55 +0500 Subject: [PATCH 07/10] Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/llava_next/convert_llava_next_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index a0bbf3ce94b58f..ef5d5913aac7ef 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -131,7 +131,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) - if model_id == "liuhaotian/llava-v1.6-mistral-7b" or model_id == "lmms-lab/llama3-llava-next-8b": + if model_id in ("liuhaotian/llava-v1.6-mistral-7b", "lmms-lab/llama3-llava-next-8b"): # Mistral-7B doesn't have a padding token set yet tokenizer.add_special_tokens({"pad_token": ""}) From 9b6ba320fdf197db927f0a14aedc1ebbcfaea2b6 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 22 Jul 2024 10:07:05 +0500 Subject: [PATCH 08/10] Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/llava_next/convert_llava_next_weights_to_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index ef5d5913aac7ef..48f08b2bb70b90 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -196,7 +196,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): del state_dict gc.collect() - # Load everything back for inference tests in float23 because prev script was written as that + # Load everything back for inference tests in float32 because prev script was written as that # Though it's mostly loaded in fp16 as original weights are in fp16 model = LlavaNextForConditionalGeneration.from_pretrained( pytorch_dump_folder_path, torch_dtype=torch.float32, device_map="auto" From ab028a10ed5500d21ff11734e14b4d9f1a421d3b Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 22 Jul 2024 07:11:54 +0200 Subject: [PATCH 09/10] remove code --- .../convert_llava_next_weights_to_hf.py | 30 ++++--------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 48f08b2bb70b90..ba52ca04f49688 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -198,9 +198,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): # Load everything back for inference tests in float32 because prev script was written as that # Though it's mostly loaded in fp16 as original weights are in fp16 - model = LlavaNextForConditionalGeneration.from_pretrained( - pytorch_dump_folder_path, torch_dtype=torch.float32, device_map="auto" - ) + model = LlavaNextForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, device_map="auto") processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path) device = model.device @@ -364,27 +362,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) print(outputs) - import os - - from huggingface_hub import HfApi - - repo_id = model_id.split("/")[-1] - print(f"Pushing to repo llava-hf/{repo_id}-hf") - - api = HfApi() - for file in os.listdir(pytorch_dump_folder_path): - api.upload_file( - path_or_fileobj=f"{pytorch_dump_folder_path}/{file}", - path_in_repo=file, - repo_id=f"llava-hf/{repo_id}-hf", - repo_type="model", - ) - - # if push_to_hub: - # repo_id = model_id.split("/")[-1] - # print(f"Pushing to repo llava-hf/{repo_id}-hf") - # model.push_to_hub(f"llava-hf/{repo_id}-hf") - # processor.push_to_hub(f"llava-hf/{repo_id}-hf") + if push_to_hub: + repo_id = model_id.split("/")[-1] + print(f"Pushing to repo llava-hf/{repo_id}-hf") + model.push_to_hub(f"llava-hf/{repo_id}-hf") + processor.push_to_hub(f"llava-hf/{repo_id}-hf") if __name__ == "__main__": From b4c645a2cb650e25f680b970e238f5e9ad4e04f1 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 22 Jul 2024 07:13:44 +0200 Subject: [PATCH 10/10] better naming --- .../models/llava_next/convert_llava_next_weights_to_hf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index ba52ca04f49688..06edc5c9b1adbc 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -363,10 +363,10 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): print(outputs) if push_to_hub: - repo_id = model_id.split("/")[-1] - print(f"Pushing to repo llava-hf/{repo_id}-hf") - model.push_to_hub(f"llava-hf/{repo_id}-hf") - processor.push_to_hub(f"llava-hf/{repo_id}-hf") + checkpoint_name = model_id.split("/")[-1] + print(f"Pushing to repo llava-hf/{checkpoint_name}-hf") + model.push_to_hub(f"llava-hf/{checkpoint_name}-hf") + processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf") if __name__ == "__main__":