Triton RMSNorm #1050

Merged
merged 14 commits into from
Apr 2, 2024

josejg (Contributor) commented Mar 21, 2024

This PR adds support for FlashAttention's Triton implementation of RMSNorm, the normalization layer commonly found in LLaMA-like models.

To enable it, set `parameters.model.norm_type` to `triton_rmsnorm`.
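
As a rough illustration of the setting above, the dotted path can be read as a nested configuration. This is only a sketch: the dictionary layout and the `get_by_dotted_path` helper are hypothetical and not part of the repository; only the key path and the value `triton_rmsnorm` come from the PR description.

```python
# Hypothetical nested config; only the key path and value are taken from the PR text.
parameters = {
    "model": {
        "norm_type": "triton_rmsnorm",
    },
}


def get_by_dotted_path(config: dict, path: str):
    """Walk a nested dict using a dot-separated key path (illustrative helper)."""
    node = config
    for key in path.split("."):
        node = node[key]
    return node


norm_type = get_by_dotted_path({"parameters": parameters}, "parameters.model.norm_type")
```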

On a LLaMA-2-7B-like workload (displayed below), this implementation improves throughput by ~1400 tok/s/GPU while preserving the loss. The improvement is this large because the current RMSNorm implementation is written in vanilla PyTorch (see here).

image
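
For readers unfamiliar with the operation, the sketch below shows what RMSNorm computes on a single row, written in plain Python rather than PyTorch so it runs anywhere. The function name, the `eps` default, and the sample values are assumptions for illustration and are not taken from the repository. A straightforward framework implementation typically performs the square, mean, reciprocal square root, and scale as separate steps, which is the kind of work a fused kernel can do in one pass.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(
    x: Sequence[float],
    weight: Optional[Sequence[float]] = None,
    eps: float = 1e-5,  # assumed default, for illustration only
) -> List[float]:
    """Scale x by the reciprocal of its root mean square, then apply an optional weight."""
    mean_sq = sum(v * v for v in x) / len(x)   # reduction over the feature dimension
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)   # eps guards against division by zero
    if weight is None:
        return [v * inv_rms for v in x]
    return [v * inv_rms * w for v, w in zip(x, weight)]


row = [1.0, -2.0, 3.0, -4.0]
normalized = rms_norm(row, weight=[1.0, 1.0, 1.0, 1.0])
```

Unlike LayerNorm, no mean is subtracted and, matching the "no bias" commit in this PR, no bias term is added.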

josejg requested a review from dakinggg on March 21, 2024 23:58

Review threads:
tests/models/test_model.py (outdated, resolved)
llmfoundry/models/layers/norm.py (resolved)
llmfoundry/models/layers/norm.py (outdated, resolved)
llmfoundry/models/layers/norm.py (outdated, resolved)

dakinggg (Collaborator) left a comment

LGTM, can you add a simple test that checks numerical equivalence to the base rmsnorm impl?
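
A numerical-equivalence check of the kind requested here could take roughly the following shape. This is a plain-Python sketch, not the test added in the PR: `rms_norm_candidate` is a stand-in for the optimized path, and the function names, dimensions, and tolerances are all assumptions chosen for illustration.

```python
import math
import random


def rms_norm_reference(x, weight, eps=1e-5):
    """Baseline: divide each element by sqrt(mean(x^2) + eps), then scale by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def rms_norm_candidate(x, weight, eps=1e-5):
    """Stand-in for an optimized path: one accumulation loop, multiply by the reciprocal."""
    acc = 0.0
    for v in x:
        acc += v * v
    inv_rms = (acc / len(x) + eps) ** -0.5
    return [w * (v * inv_rms) for v, w in zip(x, weight)]


def check_equivalence(dim=64, trials=20, rel_tol=1e-9, abs_tol=1e-12, seed=0):
    """Return True if both implementations agree within tolerance on random inputs."""
    rng = random.Random(seed)  # fixed seed keeps the check reproducible
    for _ in range(trials):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        weight = [rng.uniform(0.5, 1.5) for _ in range(dim)]
        expected = rms_norm_reference(x, weight)
        actual = rms_norm_candidate(x, weight)
        for e, a in zip(expected, actual):
            if not math.isclose(e, a, rel_tol=rel_tol, abs_tol=abs_tol):
                return False
    return True
```

In a real test the two functions would be the existing PyTorch RMSNorm and the Triton-backed one, compared on GPU tensors (for example with `torch.testing.assert_close`), likely with looser tolerances for half-precision dtypes.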

josejg enabled auto-merge (squash) on April 2, 2024 00:02
josejg merged commit b765b47 into mosaicml:main on Apr 2, 2024
9 checks passed
KuuCi pushed a commit that referenced this pull request Apr 18, 2024
* Triton RMS Norm

* fix

* refactor

* Fix test

* no bias

* Update tests/models/test_model.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* fixes

* Unittest

* fix

* fix2

* fix3

* fix3

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
3 participants