
[core / attention] Fix fused attention generation with newest transformers version #146

Merged 4 commits into main on Nov 3, 2023

Conversation

younesbelkada
Collaborator

What does this PR do?

Currently, with the latest transformers release, using AutoAWQ + fused attention with the cache is broken.
In huggingface/transformers#25242 the caching logic changed slightly: when using the transformers cache with a past key value length of 1 (as done here), the input ids are now sliced as follows:

input_ids = input_ids[:, 1:]

This means the `if seqlen == 1:` assumption used to handle the transformers cache case needs to be adapted. One can instead check whether `past_key_values` is present in kwargs and contains valid tensors, and slice out only the last token in that case (see the sketch below).
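
For illustration, here is a minimal sketch of that check, assuming the fused attention forward receives the cache via a `past_key_values` kwarg and a `[batch, seqlen, hidden]` hidden-states tensor; the exact kwarg names and cache layout in AutoAWQ's fused attention may differ:

```python
import torch

def slice_for_transformers_cache(hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
    """Keep only the newest token when a populated transformers cache is passed in."""
    # `past_key_values` mirrors the kwarg name from the PR description; the
    # actual name and cache structure inside AutoAWQ may differ.
    past_key_values = kwargs.get("past_key_values")
    has_cache = past_key_values is not None and len(past_key_values) > 0

    if has_cache and hidden_states.shape[1] > 1:
        # With a valid cache, only the last position needs to be computed,
        # regardless of how transformers sliced the input ids upstream.
        hidden_states = hidden_states[:, -1:, :]
    return hidden_states
```

Such a helper would be called at the top of the fused attention forward in place of the old `if seqlen == 1:` check.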

cc @casper-hansen

@younesbelkada
Collaborator Author

I also checked out this commit in transformers: huggingface/transformers#26162 (before huggingface/transformers#25242) and can confirm it works in both cases.

@casper-hansen
Owner

Tested and looks good. No performance regression on my end.

@casper-hansen merged commit 92a403b into main on Nov 3, 2023
@casper-hansen
Owner

I take back the remark about there being no performance regression. I tested using my benchmark.py script found in examples and saw no difference, but with the .generate() function it is 50% slower.

Slicing the hidden states in every attention layer for every token adds a lot of overhead. We should instead slice at a higher level, e.g. in the model. However, that requires implementing a LlamaModel, MistralModel, and AquilaModel. This is probably the right solution, but it requires a bit of work, which I will look into (see the sketch below).
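
A rough sketch of what slicing once at the model level could look like, assuming a thin model wrapper that owns the embedding and the decoder layers; `LlamaLikeModel`, `embed_tokens`, and `layers` are hypothetical placeholders, not the actual AutoAWQ classes:

```python
import torch
import torch.nn as nn

class LlamaLikeModel(nn.Module):
    """Hypothetical model-level wrapper used only to illustrate the idea."""

    def __init__(self, vocab_size: int = 32000, hidden_size: int = 4096, layers=None):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(layers if layers is not None else [])

    def forward(self, input_ids: torch.Tensor, past_key_values=None) -> torch.Tensor:
        # Slice once here, at the model level, instead of re-slicing the hidden
        # states inside every fused attention layer on every generated token.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]

        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, past_key_values=past_key_values)
        return hidden_states
```

The per-layer check from the merged fix could then be dropped, since every attention layer would already receive a single-token hidden state during cached generation.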

@younesbelkada
Collaborator Author

Thanks for benchmarking! Yes, slicing only once at the model level makes sense!
