
Updating the Flash Attention version to fix cross entropy loss #812

Conversation

@ShashankMosaicML (Contributor) commented Dec 19, 2023

The cross entropy loss of Flash Attention v2.3.2 (and lower) throws an illegal memory access error when the product (device train microbatch size × sequence length × vocabulary size) is large. To fix this, we had reverted to FA v1's CE loss in PR #795. However, we discovered that for very large values of that same product, FA v1's CE loss runs into numerical precision errors, causing divergence.

Newer versions of Flash Attention (v2.3.3 and higher) appear to have solved both of these problems, so in this PR we update the repo to use FA v2.3.6 (the latest version) instead of FA v2.3.2.
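To make the version requirement concrete, here is a minimal sketch (not code from this PR; the function names are hypothetical) of a guard that only enables a fused CE loss kernel when the installed flash-attn version is at or above the v2.3.3 fix described above:

```python
# Hypothetical sketch: gate use of flash-attn's fused cross entropy loss on the
# installed version, since (per this PR) versions <= 2.3.2 could hit illegal
# memory access errors and the v1 kernel could diverge from precision loss.

def parse_version(version: str) -> tuple:
    """Parse a 'major.minor.patch' string into a comparable tuple of ints."""
    return tuple(int(part) for part in version.split(".")[:3])


def supports_fused_ce(flash_attn_version: str) -> bool:
    """Return True if this flash-attn version has the fixed CE loss kernel.

    v2.3.3 is the first release said to fix both the illegal memory access
    and the numerical precision issues.
    """
    return parse_version(flash_attn_version) >= (2, 3, 3)
```

In practice the repo simply pins the dependency in `setup.py`, but a runtime check like this is one way to fail fast with a clear error instead of a CUDA illegal memory access deep inside training.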

In the plot below, the blue loss curve corresponds to the training run using FA v2.3.6's CE loss, and the pink curve to the training run using FA v1's CE loss.

[Screenshot: loss curves for FA v2.3.6 CE loss (blue) vs. FA v1 CE loss (pink)]

@ShashankMosaicML marked this pull request as ready for review December 20, 2023 00:22
@ShashankMosaicML requested a review from a team as a code owner December 20, 2023 00:22
@dakinggg (Collaborator) commented:

Throughput, memory, and loss before and after:
[Screenshots: throughput, memory, and loss metrics before and after the change]

@dakinggg (Collaborator) left a comment:

LGTM, please add a PR description explaining stuff

@ShashankMosaicML (Contributor, Author) replied:

> LGTM, please add a PR description explaining stuff

Done.

@ShashankMosaicML merged commit 2ba9224 into mosaicml:main Dec 20, 2023
10 checks passed
@ShashankMosaicML deleted the shashank/update_FA_version_fix_CE branch December 20, 2023 00:57