Skip to content

Commit

Permalink
[Awq] Add llava fused modules support (huggingface#28239)
Browse files Browse the repository at this point in the history
* add llava + fused modules

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
  • Loading branch information
2 people authored and MadElf1337 committed Jan 15, 2024
1 parent a772000 commit 7563986
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 7 deletions.
35 changes: 30 additions & 5 deletions src/transformers/integrations/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
"llava": {
"attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
"mlp": ["gate_proj", "up_proj", "down_proj"],
"layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
"use_alibi": False,
},
}


Expand Down Expand Up @@ -143,10 +149,16 @@ def get_modules_to_fuse(model, quantization_config):
elif model.config.model_type in AWQ_FUSED_MAPPINGS:
current_fused_mapping = AWQ_FUSED_MAPPINGS[model.config.model_type]

# Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
if not hasattr(model.config, "text_config"):
config = model.config
else:
config = model.config.text_config

# Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
hidden_size = model.config.hidden_size
num_attention_heads = model.config.num_attention_heads
num_key_value_heads = getattr(model.config, "num_key_value_heads", num_attention_heads)
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)

# Fill `current_fused_mapping` with the expected values
current_fused_mapping["hidden_size"] = hidden_size
Expand Down Expand Up @@ -178,6 +190,7 @@ def fuse_awq_modules(model, quantization_config):
backend = awq_config.backend

modules_to_fuse = get_modules_to_fuse(model, awq_config)
modules_to_not_convert = getattr(awq_config, "modules_to_not_convert", None)

if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.fused.attn import QuantAttentionFused
Expand All @@ -187,6 +200,10 @@ def fuse_awq_modules(model, quantization_config):
raise ValueError("Fusing is only supported for the AutoAWQ backend")

for name, module in model.named_modules():
if modules_to_not_convert is not None:
if any(module_name_to_not_convert in name for module_name_to_not_convert in modules_to_not_convert):
continue

# Replace layer norms
_fuse_awq_layernorm(modules_to_fuse["layernorm"], module, FasterTransformerRMSNorm)

Expand Down Expand Up @@ -248,7 +265,14 @@ def _fuse_awq_mlp(model, current_module_name, fuse_module_names, module, target_
down_proj = getattr(module, fuse_module_names[2])

previous_device = gate_proj.qweight.device
activation_fn = ACT2FN[model.config.hidden_act]

# Deal also with the case model has `text_config` attribute
hidden_act = (
model.config.hidden_act
if not hasattr(model.config, "text_config")
else model.config.text_config.hidden_act
)
activation_fn = ACT2FN[hidden_act]
new_module = target_cls(gate_proj, down_proj, up_proj, activation_fn)

parent_name, child_name = current_module_name.rsplit(".", 1)
Expand Down Expand Up @@ -284,7 +308,6 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
if hasattr(module, modules_to_fuse["attention"][0]):
# First, we pack the QKV layers together
q_proj = getattr(module, modules_to_fuse["attention"][0])
previous_device = q_proj.qweight.device

if isinstance(q_proj, WQLinear_GEMV):
linear_target_cls = WQLinear_GEMV
Expand All @@ -295,6 +318,8 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
else:
raise ValueError("Unsupported q_proj type: {type(q_proj)}")

previous_device = q_proj.qweight.device

k_proj = getattr(module, modules_to_fuse["attention"][1])
v_proj = getattr(module, modules_to_fuse["attention"][2])
o_proj = getattr(module, modules_to_fuse["attention"][3])
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3583,6 +3583,14 @@ def from_pretrained(

if quantization_config is None:
quantization_config = AwqConfig.from_dict(config.quantization_config)
# In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
# passed `quantization_config`
elif (
quantization_config.modules_to_not_convert is None
and "modules_to_not_convert" in config.quantization_config
):
quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]

if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,15 @@ def forward(
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[batch_index, non_attended_tokens] = 0
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
Expand Down
9 changes: 8 additions & 1 deletion src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,15 @@ def forward(
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# in the case one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[batch_index, non_attended_tokens] = 0
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
Expand Down
26 changes: 26 additions & 0 deletions tests/quantization/autoawq/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ class AwqFusedTest(unittest.TestCase):
custom_mapping_model_id = "TheBloke/Yi-34B-AWQ"
custom_model_revision = "f1b2cd1b7459ceecfdc1fac5bb8725f13707c589"

multi_modal_model_name = "ybelkada/llava-1.5-7b-hf-awq"
multi_modal_model_code_revision = "ad108a50f5b9e681bdd7378409f57b7fa59a7442"

prompt = (
"You're standing on the surface of the Earth. "
"You walk one mile south, one mile west and one mile north. "
Expand Down Expand Up @@ -344,6 +347,29 @@ def test_generation_fused_batched(self):

self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)

def test_generation_llava_fused(self):
from transformers import pipeline

quantization_config = AwqConfig(do_fuse=True, fuse_max_seq_len=2048)

pipe = pipeline(
"image-to-text",
model=self.multi_modal_model_name,
device=0,
model_kwargs={
"quantization_config": quantization_config,
},
revision=self.multi_modal_model_code_revision,
)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"

prompt = "USER: <image>\nCan you please describe this image?\nASSISTANT:"

outputs = pipe(url, prompt=prompt, generate_kwargs={"max_new_tokens": 100})
EXPECTED_OUTPUT = "USER: \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on a green surface, possibly a carpet or a grassy area. The cat is holding a red ball in its paws, seemingly playing with it. The cat appears to be focused on the ball, possibly preparing to play or just enjoying the toy."

self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)

@require_torch_multi_gpu
def test_generation_custom_model(self):
"""
Expand Down

0 comments on commit 7563986

Please sign in to comment.