
[GPT2] Add SDPA support #31172

Merged · 16 commits merged into huggingface:main · Jun 19, 2024
Conversation

@vasqu vasqu commented Jun 1, 2024

What does this PR do?

Adds torch's scaled_dot_product_attention (SDPA) to the GPT2 model architecture as another attention module. Possibly relevant: #28005.
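For context, torch's fused SDPA kernel computes the same causal attention as GPT-2's eager path; a toy equivalence sketch (illustrative shapes only, not the model's actual code):

```python
import torch
import torch.nn.functional as F

# Toy shapes mirroring GPT-2 attention: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

# Fused kernel path that this PR wires into GPT2Attention.
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent eager math: scaled scores, causal mask, softmax, weighted sum.
scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))
eager_out = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ v

print(torch.allclose(sdpa_out, eager_out, atol=1e-5))
```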

I'll be adding some benchmarks like in #31031 for gpt2-large (same setup). Docs will be edited afterwards, and I'll share my results in this thread. I've checked the two most important tests that validate SDPA vs eager:

  • RUN_SLOW=True pytest tests/models/gpt2 -k "test_eager_matches_sdpa_generate" -s -vvvvv
  • RUN_SLOW=True pytest tests/models/gpt2 -k "test_eager_matches_sdpa_inference" -s -vvvvv

Both pass without any problems, so the remaining work should mostly be style/docs (after my benchmarks).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @younesbelkada @fxmarty @amyeroberts

@vasqu vasqu force-pushed the gpt2-sdpa branch 3 times, most recently from 2357341 to de45320 Compare June 1, 2024 01:24
@vasqu vasqu changed the title GPT2 Add SDPA support [GPT2] Add SDPA support Jun 1, 2024
vasqu commented Jun 1, 2024

Training benchmark:

| Batch size | Seq len | Time per batch, Eager (s) | Time per batch, SDPA (s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 0.039 | 0.032 | 23.042 | 3482.32 | 3494.62 | -0.352 |
| 1 | 256 | 0.073 | 0.059 | 25.15 | 3546.66 | 3552.6 | -0.167 |
| 1 | 512 | 0.155 | 0.118 | 30.96 | 4230.1 | 3665.59 | 15.4 |
| 1 | 1024 | 0.316 | 0.209 | 50.839 | 8682.26 | 4881.09 | 77.875 |
| 2 | 128 | 0.07 | 0.06 | 15.324 | 3557.8 | 3545.91 | 0.335 |
| 2 | 256 | 0.143 | 0.122 | 16.53 | 3901.5 | 3657.68 | 6.666 |
| 2 | 512 | 0.267 | 0.213 | 25.626 | 7062.21 | 4876.47 | 44.822 |
| 2 | 1024 | OOM | 0.404 | / | OOM | 8096.35 | SDPA does not OOM |
| 4 | 128 | 0.134 | 0.128 | 4.412 | 3675.79 | 3648.72 | 0.742 |
| 4 | 256 | 0.243 | 0.217 | 12.292 | 6129.76 | 4871.12 | 25.839 |
| 4 | 512 | 0.494 | 0.406 | 21.687 | 12466.6 | 8102.64 | 53.858 |
| 4 | 1024 | OOM | 0.795 | / | OOM | 14568.2 | SDPA does not OOM |
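For anyone wanting to reproduce numbers like the training rows above, a minimal sketch of per-batch timing (a stand-in `model` and `batch` here; the real runs used gpt2-large with the #31031 setup):

```python
import time

import torch

def time_per_batch(model, batch, n_iters=10, warmup=3):
    """Average forward+backward wall time per batch, in seconds."""
    for _ in range(warmup):
        model(batch).sum().backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        model(batch).sum().backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters

# Stand-in model and batch; a real benchmark would use a GPT-2 checkpoint
# loaded with eager vs sdpa attention and compare the two timings.
model = torch.nn.Linear(64, 64)
batch = torch.randn(8, 64)
print(f"{time_per_batch(model, batch):.6f} s/batch")
```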

Inference benchmark:

| Batch size | Seq len | Per-token latency, Eager (ms) | Per-token latency, SDPA (ms) | Speedup (%) | Mem Eager (MB) | Mem SDPA (MB) | Mem saved (%) |
|---|---|---|---|---|---|---|---|
| 1 | 128 | 7.991 | 6.968 | 14.681 | 1685.2 | 1701.32 | -0.947 |
| 1 | 256 | 8.462 | 7.199 | 17.536 | 1745.49 | 1770.78 | -1.428 |
| 1 | 512 | 8.68 | 7.853 | 10.529 | 1907.69 | 1921.29 | -0.708 |
| 1 | 768 | 9.101 | 8.365 | 8.791 | 2032.93 | 2068.12 | -1.701 |
| 2 | 128 | 9.169 | 9.001 | 1.861 | 1803.84 | 1811.4 | -0.418 |
| 2 | 256 | 9.907 | 9.78 | 1.294 | 1907.72 | 1921.44 | -0.714 |
| 2 | 512 | 11.519 | 11.644 | -1.071 | 2176.86 | 2197.75 | -0.951 |
| 2 | 768 | 13.022 | 13.407 | -2.873 | 2464.3 | 2491.06 | -1.074 |
| 4 | 128 | 10.097 | 9.831 | 2.709 | 1942.25 | 1985.13 | -2.16 |
| 4 | 256 | 11.599 | 11.398 | 1.764 | 2177.28 | 2197.86 | -0.937 |
| 4 | 512 | 14.653 | 14.45 | 1.411 | 2753.16 | 2772.57 | -0.7 |
| 4 | 768 | 17.846 | 17.617 | 1.299 | 3327.04 | 3343.97 | -0.506 |

vasqu commented Jun 1, 2024

I'm a bit surprised by the inference results (regarding memory), but I don't have anything else to add to the code or docs. Tests seem to pass, including the slow SDPA-vs-eager tests (at least locally).

Aside from that, commit d963ad5 failed the hub tests. This seems very odd and may need a separate investigation, since my code shouldn't have affected that test (shown by the empty commit 91fe533 right afterwards, which passes). It might be a bug in the trainer or somewhere along that path.

@younesbelkada (Contributor) left a comment:

Thanks a lot for this contribution and for adding the extensive benchmark! I left one open question; what do you think?

src/transformers/models/gpt2/modeling_gpt2.py (resolved)
@younesbelkada (Contributor) left a comment:

Thanks a lot!

@ArthurZucker (Collaborator) left a comment:

Left some small nits, but overall good job! 👍🏻

src/transformers/models/gpt2/modeling_gpt2.py (outdated, resolved)
src/transformers/models/gpt2/modeling_gpt2.py (outdated, resolved)
src/transformers/models/gpt2/modeling_gpt2.py (resolved)
src/transformers/models/gpt2/modeling_gpt2.py (outdated, resolved)
vasqu added 2 commits June 3, 2024 16:18
- only save _attn_implentation once
- remove unnecessary comment
@ArthurZucker (Collaborator) left a comment:

Could you simply run the slow GPT2 tests to make sure everything is alright? SDPA will be the new default!

vasqu commented Jun 6, 2024

Did that commit 4811cb5 trigger the slow CI run, or do the tests only run locally? @ArthurZucker

vasqu commented Jun 7, 2024

Ok, it doesn't look like the slow tests ran 😓 I guess I need a different commit message or permissions?

@younesbelkada
@vasqu if you have access to a GPU, you can run them locally with RUN_SLOW=1 pytest tests/models/gpt2/test_modeling_gpt2.py

vasqu commented Jun 7, 2024

Ah, I see. I'll try and report back, thanks @younesbelkada. I don't have Flash Attention enabled though (if that's an issue); enabling it would require me to update my CUDA, which is a pain.

@younesbelkada
That shouldn't be an issue IMO; we just want to make sure the current slow integration tests still pass with SDPA!

vasqu commented Jun 7, 2024

slow_test_gpt2_sample.txt
slow_test_gpt2_contrastive.txt
slow_test_gpt2.txt

When run all in one go, three tests fail. I'm not sure why, because in isolation only test_gpt2_sample_max_time fails.

Edit: The contrastive search and sampling tests fail due to too little VRAM 😆 They work in isolation though.

vasqu commented Jun 7, 2024

start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

Should be changed with

start = datetime.datetime.now()
model.generate(input_ids, do_sample=False, max_time=None, max_length=256)
duration = datetime.datetime.now() - start
self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))

That would follow the pattern of all assertions before that and then the tests also pass.

Comment on lines -835 to +836:
- self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
+ self.assertGreater(duration, datetime.timedelta(seconds=MAX_TIME))
+ self.assertLess(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
vasqu (Contributor, author) commented:

This is the only failing test without this modification. I'm not sure whether this is how it was intended. Maybe the assertGreater should have been an assertLess instead?

younesbelkada (Contributor) commented:

I think this change is totally fine; now that SDPA is used, generation is faster.

@younesbelkada (Contributor) left a comment:

Thanks a lot for working on this!

@ArthurZucker (Collaborator) left a comment:

Looks good to me! Thanks for adding this.

# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
A collaborator commented:
good!
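The gate quoted above exists because F.scaled_dot_product_attention returns only the attention output, not the softmax weights, and takes no per-head mask. A hypothetical standalone sketch of that fallback logic (illustrative only, not the model's actual code):

```python
import torch
import torch.nn.functional as F

def attention_forward(q, k, v, attn_implementation="sdpa",
                      output_attentions=False, head_mask=None):
    # Mirror of the gate above: SDPA cannot return attention weights and
    # does not accept a per-head mask, so those cases fall back to eager.
    use_sdpa = attn_implementation == "sdpa" and output_attentions is False and head_mask is None
    if use_sdpa:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True), None
    # Eager path: explicit scores, causal mask, softmax, optional head mask.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    causal = torch.tril(torch.ones(q.size(-2), k.size(-2), dtype=torch.bool))
    weights = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1)
    if head_mask is not None:
        weights = weights * head_mask
    return weights @ v, weights

q = k = v = torch.randn(1, 2, 4, 8)
out_sdpa, attn = attention_forward(q, k, v)
out_eager, attn_eager = attention_forward(q, k, v, output_attentions=True)
```

Requesting the attention weights silently routes through the eager branch, which is exactly why the slow eager-vs-SDPA equivalence tests matter.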

@ArthurZucker ArthurZucker merged commit b275a41 into huggingface:main Jun 19, 2024
22 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu deleted the gpt2-sdpa branch June 19, 2024 08:14
itazap pushed a commit that referenced this pull request Jun 20, 2024
* `gpt2` sdpa support

* fix (at least) one test, style, repo consistency

* fix sdpa mask in forward --> fixes generation

* test

* test2

* test3

* test4

* simplify shapes for attn mask creation and small comments

* hub fail test

* benchmarks

* flash attn 2 mask should not be inverted on enc-dec setup

* fix comment

* apply some suggestion from code review

- only save _attn_implentation once
- remove unnecessary comment

* change elif logic

* [run-slow] gpt2

* modify `test_gpt2_sample_max_time` to follow previous assertion patterns
4 participants