Add shifted sparse attention #973
Conversation
(Force-pushed from 42a0645 to a6be9cb)
@joecummings do you have time to rebase this onto main? If not, I can take a stab at rebasing later this week.

yep I'll do this later today!
(Force-pushed from cb08b2d to 7628056)
(Force-pushed from 0412089 to 4135039)
Looks good to me. @NanoCode012 can you take a quick look please to make sure I didn't miss anything? Thanks
```python
# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
```
@joecummings this should work with mistral and mixtral too, right?
```python
        cross_entropy=cfg.flash_attn_cross_entropy,
        rms_norm=cfg.flash_attn_rms_norm,
    )
if cfg.sample_packing:
    if cfg.device not in ["mps", "cpu"] and not inference:
```
Hm, does this mean FA won't be enabled in inference mode now?
I don't think FA was ever enabled for inference. Here's the original code:
```python
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
    if cfg.device not in ["mps", "cpu"] and not inference:
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            replace_llama_attn_with_flash_attn,
        )

        LOG.info("patching with flash attention for sample packing")
        replace_llama_attn_with_flash_attn(
            packed=cfg.sample_packing,
            cross_entropy=cfg.flash_attn_cross_entropy,
            rms_norm=cfg.flash_attn_rms_norm,
        )
```
Thanks for all your work on this @joecummings !
Summary
Add shifted sparse attention (w/ flash attention) to enable longer context training w/ less memory overhead.
Paper: https://arxiv.org/pdf/2309.12307.pdf
Code: https://github.com/dvlab-research/LongLoRA/tree/main
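For background, the core trick of shifted sparse attention (S2-Attn) in the paper is to compute attention within fixed-size token groups, while half of the attention heads operate on a sequence shifted by half a group, so information still flows across group boundaries. A toy illustration of the grouping pattern (not this PR's implementation):

```python
def shift(tokens, group_size):
    """Roll the sequence by half a group so group boundaries overlap."""
    half = group_size // 2
    return tokens[half:] + tokens[:half]

def grouped(tokens, group_size):
    """Partition a sequence into fixed-size attention groups."""
    return [tokens[i:i + group_size] for i in range(0, len(tokens), group_size)]

seq = list(range(8))  # token positions 0..7
g = 4                 # group size

# Half the heads attend within plain groups...
plain = grouped(seq, g)              # [[0, 1, 2, 3], [4, 5, 6, 7]]
# ...the other half attend within shifted groups, bridging the boundary.
shifted = grouped(shift(seq, g), g)  # [[2, 3, 4, 5], [6, 7, 0, 1]]
```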
Testing
- Added a test to check for a raised `ValueError` if `sample_packing = True` and `s2_attention = True`:
  `pytest tests/utils/test_models.py::ModelsUtilsTest::test_cfg_throws_error_with_s2_attention_and_sample_packing`
- Ran `accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml` with the following config changes: [INSERT WANDB LOG HERE]
Follow-ups
- Train `embed` and `norm` during LoRA, which improves performance according to the above paper (e.g. LoRA+)
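In axolotl-config terms, that follow-up would presumably amount to marking the embedding and norm modules as fully trainable alongside the LoRA adapters, e.g. via `lora_modules_to_save` (a hedged sketch; the exact module names vary by model and are assumptions here):

```yaml
adapter: lora
lora_r: 8
lora_alpha: 16
# Hypothetical: train these modules at full rank alongside the adapters,
# as the LongLoRA paper recommends for long-context fine-tuning.
lora_modules_to_save:
  - embed_tokens
  - norm
```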