[GPT2] Add SDPA support #31172
Conversation
Force-pushed from 2357341 to de45320
Training benchmark:
Inference benchmark:
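As a rough illustration of how such a micro-benchmark could be set up (a hypothetical sketch, not the script used for the numbers above), one can time eager attention against `torch.nn.functional.scaled_dot_product_attention` on random tensors and check that both paths agree numerically:

```python
# Hypothetical micro-benchmark sketch: eager attention vs torch SDPA.
# Shapes and iteration counts are illustrative, not the PR's benchmark setup.
import math
import time

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, dim = 2, 4, 64, 32
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

def eager_attention(q, k, v):
    # Plain softmax(QK^T / sqrt(d)) V, as in the eager attention path.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

def timed(fn, iters=10):
    start = time.perf_counter()
    for _ in range(iters):
        out = fn()
    return out, time.perf_counter() - start

eager_out, eager_t = timed(lambda: eager_attention(q, k, v))
sdpa_out, sdpa_t = timed(lambda: F.scaled_dot_product_attention(q, k, v))

# Both implementations should agree up to float tolerance.
assert torch.allclose(eager_out, sdpa_out, atol=1e-4)
print(f"eager: {eager_t:.4f}s  sdpa: {sdpa_t:.4f}s")
```

A real benchmark would additionally warm up, use larger shapes on GPU, and synchronize CUDA before reading timers.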
I'm a bit surprised by the inference results (regarding memory), but I don't have anything else to add to the code and/or docs. Tests seem to pass, including the sdpa vs eager slow tests (at least locally). Aside from that, commit d963ad5 failed on the tests for the hub. This seems very weird to me and may need a separate investigation, since my code shouldn't have affected that test (shown by the empty commit 91fe533 right afterwards, which passes). Might be a bug in the trainer or somewhere along there.
Thanks a lot for this contribution and for adding the extensive benchmark! I left one open question, what do you think?
Thanks a lot!
Left some small nits, but overall good job! 👍🏻
- only save `_attn_implementation` once
- remove unnecessary comment
Could you simply run the slow GPT2 tests to make sure everything is alright? SDPA will be the new default!
Did that commit 4811cb5 work, or is it done locally only? @ArthurZucker
Ok, doesn't look like it ran in slow 😓 I guess I need a different commit msg or perms?
@vasqu if you have access to a GPU, you can run them locally with:
Ah I see, I'll try and report back, thanks @younesbelkada. I don't have Flash Attention enabled though (if that's an issue); enabling it would require me to update my CUDA, which is a pain.
That shouldn't be an issue IMO, we just want to make sure the current slow integration tests still pass with SDPA!
slow_test_gpt2_sample.txt

When run all in one go, three tests fail. Not sure why, because in isolation they behave differently. Edit: seeing that the contrastive search and sampling tests fail due to too low VRAM 😆 They work in isolation tho.
transformers/tests/models/gpt2/test_modeling_gpt2.py, lines 832 to 835 in 48d35b2:
Should be changed to:

```python
start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
```

That would follow the pattern of all the assertions before it, and then the tests also pass.
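For reference, the duration-bounds assertion pattern can be illustrated with a self-contained sketch, where `time.sleep` stands in for the uncapped `model.generate` call (values here are hypothetical, not from the test suite):

```python
# Illustration of the duration-bounds pattern: the run must take longer
# than MAX_TIME (no cap applied) but stay under 1.5 * MAX_TIME.
import datetime
import time

MAX_TIME = 0.5  # hypothetical bound, in seconds

start = datetime.datetime.now()
time.sleep(1.2 * MAX_TIME)  # stand-in for model.generate(..., max_time=None)
duration = datetime.datetime.now() - start

assert duration > datetime.timedelta(seconds=MAX_TIME)
assert duration < datetime.timedelta(seconds=1.5 * MAX_TIME)
```

The same two comparisons map directly onto `assertGreater` and `assertLess` in the unittest-based test.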
```diff
-        self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+        self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+        self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
```
Only failing test without this modification. I'm not sure if this is how it is intended. Maybe `assertGreater` should be `assertLess` instead?
I think that this change is totally fine; now SDPA is used, so the generation is faster.
Thanks a lot for working on this!
Looks good to me! Thanks for adding.
```python
# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
```
good!
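To spell out the intent behind a gate like `_use_sdpa` (a hypothetical standalone sketch, not the model code itself): SDPA cannot return per-head attention weights and this integration does not route a head mask through it, so those cases must fall back to the eager attention path.

```python
# Hypothetical sketch of the SDPA eligibility check: mirror of the
# `_use_sdpa` condition, written as a free function for illustration.
def can_use_sdpa(attn_implementation, output_attentions, head_mask):
    return (
        attn_implementation == "sdpa"   # SDPA was requested in the config
        and output_attentions is False  # SDPA cannot return attention weights
        and head_mask is None           # head masking not supported on this path
    )

assert can_use_sdpa("sdpa", False, None)
assert not can_use_sdpa("sdpa", True, None)    # caller wants attention weights
assert not can_use_sdpa("sdpa", False, [1.0])  # head mask present
assert not can_use_sdpa("eager", False, None)  # eager explicitly requested
```

When the check fails, the model silently takes the eager path, which is why `output_attentions=True` still works after this PR.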
* `gpt2` sdpa support
* fix (at least) one test, style, repo consistency
* fix sdpa mask in forward --> fixes generation
* test
* test2
* test3
* test4
* simplify shapes for attn mask creation and small comments
* hub fail test
* benchmarks
* flash attn 2 mask should not be inverted on enc-dec setup
* fix comment
* apply some suggestion from code review - only save _attn_implentation once - remove unnecessary comment
* change elif logic
* [run-slow] gpt2
* modify `test_gpt2_sample_max_time` to follow previous assertion patterns
What does this PR do?
Adds torch's SDPA to the GPT2 model architecture (as another attention module). Possibly relevant: #28005.
I'll be adding some benchmarks like in #31031 for gpt2-large (same setup). Docs will be edited afterwards, and I'll share my results in this thread. I've checked the two most important tests that validate SDPA vs eager:
```shell
RUN_SLOW=True pytest tests/models/gpt2 -k "test_eager_matches_sdpa_generate" -s -vvvvv
RUN_SLOW=True pytest tests/models/gpt2 -k "test_eager_matches_sdpa_inference" -s -vvvvv
```
Both pass without any problems, so the rest should mostly be about style / docs (after my benchmarks).
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada @fxmarty @amyeroberts