Llama 3 8b OOM with GaLore on 2x A100s (Mistral 7b is fine?) #1641
Comments
How much VRAM does Mistral 7B use with GaLore on your setup? Keep in mind that Llama 3 has a much larger embedding layer than Mistral (a 128k vs. 32k vocabulary), which significantly increases VRAM use. I would start by decreasing the batch sizes and using unsloth gradient checkpointing:

```yaml
micro_batch_size: 4
eval_batch_size: 4
gradient_checkpointing: unsloth
```
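As a rough back-of-the-envelope check on the vocabulary difference (using the published architectures, not numbers from this thread: hidden size 4096 for both models, untied input and output embeddings):

$$
\underbrace{2 \times 128{,}256 \times 4096}_{\text{Llama 3 8b}} \approx 1.05\,\text{B params}
\qquad
\underbrace{2 \times 32{,}000 \times 4096}_{\text{Mistral 7b}} \approx 0.26\,\text{B params}
$$

On top of the extra weights, gradients, and optimizer states, the logits tensor at 8k sequence length is $8192 \times 128{,}256$ values per sample, roughly 4.2 GB per fp32 copy, which is often what actually runs out of memory at long sequence lengths.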
VRAM usage with Mistral 7B seems to be only about 58 GB on 1x A100. As for Llama 3, I decreased the batch size to 1, set gradient checkpointing to unsloth, increased gradient accumulation steps to 6, and rented 8x A6000s. It still OOMs even with 383.9 GB of VRAM available for finetuning an 8b model. This can't be right, can it?

Edit: this may be related to issue #1448

Edit 2: I was using the wrong config -- it does not OOM on 8x A6000s. However, it IS using 273 GB of VRAM. Seems a bit high?
@e-p-armstrong did you find any workaround?
@Abhis-123 galore and multi-gpus do not play nice. Gotta use a single GPU, or multiple GPUs with paged adamw and deepspeed (look at config.qmd for options, I think).
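For context, a minimal sketch of that suggested multi-GPU setup, assuming the stock DeepSpeed configs that ship with axolotl; check config.qmd for the exact option names:

```yaml
# Sketch of the multi-GPU alternative: drop GaLore and use a paged 8-bit
# AdamW optimizer with DeepSpeed ZeRO-2 sharding across GPUs.
optimizer: paged_adamw_8bit
# Path assumes the deepspeed_configs/ directory bundled with axolotl.
deepspeed: deepspeed_configs/zero2.json
```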
Please check that this issue hasn't been reported before.
Expected Behavior
Llama 3 8b, with only one billion more parameters than Mistral 7b, should presumably be able to GaLore train on at least 2x A100s (Mistral v0.2 can train on 1x A100).
Current behaviour
Llama 3 8b OOMs immediately when tuned with GaLore at 8k sequence length, even when obscene amounts of compute are thrown at it.
Steps to reproduce
Config yaml
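The attached config was collapsed in the original report and is not recoverable here. What follows is a hypothetical sketch of a GaLore full-finetune config matching the setup described (Llama 3 8b, 8k sequence length); every value is assumed for illustration, not taken from the reporter's file:

```yaml
# Hypothetical reconstruction -- the reporter's actual config was not captured.
base_model: meta-llama/Meta-Llama-3-8B-Instruct
sequence_len: 8192
micro_batch_size: 1
gradient_accumulation_steps: 6
gradient_checkpointing: unsloth
# GaLore in axolotl is selected via the optimizer name; the target modules
# and hyperparameters below are illustrative guesses.
optimizer: galore_adamw
optim_target_modules:
  - self_attn
  - mlp
optim_args:
  rank: 256
  update_proj_gap: 200
  scale: 0.25
```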
Possible solution
I was able to finetune Llama 3 8b instruct by reducing the sequence length to around 2000 tokens (see the sketch below). However, unless I'm missing something big that makes Llama 3 very inefficient with GaLore, I should presumably be able to do a full finetune at 8000 sequence length with two whole A100s.
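A minimal sketch of that workaround, assuming the rest of the config is unchanged; the 2048 value is illustrative of "around 2000 tokens":

```yaml
# Workaround sketch: shrink the context so the seq_len x 128k-vocab logits
# tensor (and its gradients) fits in memory. 2048 is an assumed round value.
sequence_len: 2048
```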
Which Operating Systems are you using?
Python Version
3.10.14
axolotl branch-commit
Whatever the official Docker image is on.
Acknowledgements