
[Awq] Add llava fused modules support #28239

Merged
merged 3 commits into huggingface:main from llava-fused-modules
Jan 12, 2024

Conversation

Contributor
@younesbelkada younesbelkada commented Dec 25, 2023

What does this PR do?

This PR adds Llava + fused modules support for blazing fast text generation with Llava + AWQ!
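For reference, enabling the fused modules from the user side looks roughly like the sketch below (the checkpoint id is a placeholder and the exact AwqConfig values are assumptions, not something this PR prescribes):

import torch
from transformers import AutoProcessor, AwqConfig, LlavaForConditionalGeneration

model_id = "some-org/llava-1.5-7b-awq"  # placeholder id for an AWQ-quantized Llava checkpoint

quantization_config = AwqConfig(
    bits=4,
    do_fuse=True,          # turn on the fused modules added by this PR
    fuse_max_seq_len=512,  # max sequence length the fused cache is allocated for
)

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
)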

This PR also fixes the issue a user pointed out in #28032 (comment): because a custom past key value is passed to the model, some non-attended token indices fall outside the range of extended_attention_mask; filtering those out fixes the issue.

Also added a slow test.

Can also confirm all Llava slow tests pass!

cc @casper-hansen

@cbjtu

cbjtu commented Dec 26, 2023

Thank you so much, this PR and #28032 got everything working for me now!

@ArthurZucker ArthurZucker self-requested a review January 2, 2024 10:00
Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks a lot! More discussion on the attended / non-attended tokens is needed IMO! 🤗

src/transformers/integrations/awq.py
Comment on lines +3577 to +3584
# In case a user passes a `AwqConfig` with `do_fuse=True` for models that have
# a `modules_to_not_convert` attribute we need to manually set that attribute into the
# passed `quantization_config`
elif (
quantization_config.modules_to_not_convert is None
and "modules_to_not_convert" in config.quantization_config
):
quantization_config.modules_to_not_convert = config.quantization_config["modules_to_not_convert"]
Collaborator


That seems a bit odd to me. It should either be done in the integration (I know you don't have access to the config there), or, when you init the quantization_config, you should use config.quantization_config, no? (At some point merging kwargs?)

Contributor Author


I think we always rely on quantization_config. We do merge kwargs, but in the other direction (from quantization_config to config.quantization_config) with get_loading_attributes(). The scenario above happens only in the specific case where a user passes do_fuse=True and a non-None value in config.quantization_config["modules_to_not_convert"]. I think it is a good idea to think of a way to harmonize how kwargs are merged between config.quantization_config and quantization_config, but that might be slightly out of scope for this PR since I would need to do it for all the quantization schemes we support. I propose to do that properly in a follow-up PR.
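To illustrate the two directions being discussed, here is a toy sketch with hypothetical dict contents (not the real transformers internals):

import copy

# quantization config stored in the checkpoint's config.json (hypothetical values)
config_quantization_config = {"bits": 4, "modules_to_not_convert": ["multi_modal_projector"]}
# quantization config passed by the user at load time
user_quantization_config = {"do_fuse": True, "fuse_max_seq_len": 512, "modules_to_not_convert": None}

# usual direction (get_loading_attributes-style): user-passed loading attributes
# override what is stored in config.quantization_config
merged = copy.deepcopy(config_quantization_config)
merged.update({k: v for k, v in user_quantization_config.items() if v is not None})

# special case handled by the snippet above: the user passed do_fuse=True but no
# modules_to_not_convert, so it is inherited back from config.quantization_config
if user_quantization_config["modules_to_not_convert"] is None:
    user_quantization_config["modules_to_not_convert"] = config_quantization_config["modules_to_not_convert"]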

Collaborator


Alright, let's keep that in mind.

src/transformers/models/llava/modeling_llava.py (outdated)
Comment on lines +451 to +453
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
Collaborator


Not a fan of adding custom code only to handle custom usages. There should be a more general way of handling these things (why use the extended attention mask and not just the attention mask, why not use the past key value length, etc.).

Contributor Author


This is not to handle a custom usage; it happens when a past key value with pad tokens has indices that are larger than the extended attention mask shape: #28032 (comment) & #28239 (comment). This can mainly happen in batched generation with long sequence lengths, and it specifically happens for autoawq fused modules because the dummy past key values are initialized with all zeros: https://github.com/casper-hansen/AutoAWQ/blob/a3db8099a234a46a21bf5e46340da60da6992e0c/awq/modules/fused/attn.py#L238
In any case I don't think this will cause any harm, since it just filters out indices of pad tokens (which are not attended anyway) that fall outside the extended attention mask range, and I confirmed all slow tests pass.
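To make the effect of the filtering concrete, here is a toy example (illustrative tensors only, not taken from the PR):

import torch

extended_attention_mask = torch.ones(2, 10)     # mask covers 10 positions
batch_index = torch.tensor([0, 0, 1])
non_attended_tokens = torch.tensor([3, 12, 7])  # 12 is outside the mask range

# keep only the non-attended (pad) positions that actually fall inside the mask
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]                  # tensor([0, 1])
new_non_attended_tokens = non_attended_tokens[valid_indices]  # tensor([3, 7])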

Collaborator


Alright, I thought this was already solved. No worries then; it's just that tensor indexing might slow things down a bit, but it is required anyway. I think a refactor might help:

  • Init the embeddings with a different value (like -1, which might not happen as often as zeros) when we compute the image indexes
  • Correctly update the attention mask when merging to make sure we keep track of what we computed

I'd be in favor of moving this fix to another PR maybe? WDYT?

Contributor Author


I see, thanks!

I am happy to explore:

Init the embeddings with a different value (like -1, which might not happen as often as zeros) when we compute the image indexes

In another PR!

I'd be in favor of moving this fix to another PR maybe? WDYT?

That might not be ideal, because if this fix is not introduced users cannot run llava + fused modules :/ I'll address the points you shared in a follow-up PR!

Collaborator

@ArthurZucker ArthurZucker left a comment


thanks


@younesbelkada
Contributor Author

Thanks for your reviews @ArthurZucker! Merging! I'll address the points you shared in #28239 (comment) in another PR, as stated in my reply.

@younesbelkada younesbelkada merged commit 07bdbeb into huggingface:main Jan 12, 2024
21 checks passed
@younesbelkada younesbelkada deleted the llava-fused-modules branch January 12, 2024 05:55
staghado pushed a commit to staghado/transformers that referenced this pull request Jan 15, 2024
* add llava + fused modules

* Update src/transformers/models/llava/modeling_llava.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
MadElf1337 pushed a commit to MadElf1337/transformers that referenced this pull request Jan 15, 2024
wgifford pushed a commit to wgifford/transformers that referenced this pull request Jan 21, 2024
AjayP13 pushed a commit to AjayP13/transformers that referenced this pull request Jan 22, 2024