
Fix for OOM Errors during Ultrachat200k Finetuning #2180

Merged
merged 4 commits into main on Mar 14, 2024
Conversation

@Satrat (Contributor) commented Mar 13, 2024

The following SparseZoo recipes were causing OOM errors even with FSDP:

  • "zoo:llama2-7b-ultrachat200k_llama2_pretrain-pruned40"
  • "zoo:llama2-7b-ultrachat200k_llama2_pretrain-pruned40_quantized"

The fix was to offload weights to the CPU when gathering FSDP parameters during modifier initialization and finalization.
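
For context, the general pattern looks like the following minimal sketch, assuming PyTorch's FSDP.summon_full_params context manager; modifier and its initialize method are hypothetical stand-ins for SparseML's modifier lifecycle hooks, not the exact code from this PR:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def initialize_modifier(model: FSDP, modifier):
    # Gather the full, unsharded parameters for the duration of the context,
    # but offload them to CPU rather than materializing a complete copy on
    # every GPU; this is what avoids the OOM during initialize/finalize.
    with FSDP.summon_full_params(model, offload_to_cpu=True):
        modifier.initialize(model)  # hypothetical modifier lifecycle hook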

Test

Tested on six 48GB GPUs. The example below only runs a few training samples to keep the test quick.

from sparseml.transformers import compress, SparseAutoModelForCausalLM, SparseAutoTokenizer

# Load the dense base model, an identical copy to serve as the distillation
# teacher, and the matching tokenizer from SparseZoo
model = SparseAutoModelForCausalLM.from_pretrained("zoo:llama2-7b-ultrachat200k_llama2_pretrain-base")
teacher = SparseAutoModelForCausalLM.from_pretrained("zoo:llama2-7b-ultrachat200k_llama2_pretrain-base")
tokenizer = SparseAutoTokenizer.from_pretrained("zoo:llama2-7b-ultrachat200k_llama2_pretrain-base")

dataset = "open_platypus"
MODEL_STUB = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-pruned40"  # recipe that previously OOMed

# Run compression with the pruning recipe; num_train_epochs is fractional so
# only a handful of training samples are processed
compress(
    model=model,
    distill_teacher=teacher,
    tokenizer=tokenizer,
    dataset=dataset,
    recipe=MODEL_STUB,
    output_dir="./output",
    gradient_checkpointing=True,
    num_train_epochs=0.02,
)

Run with FSDP: accelerate launch --config_file fsdp_config.yaml test.py
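
For reference, an illustrative fsdp_config.yaml for a six-GPU run; the keys below follow accelerate's FSDP config format, but this is an assumed example, not the exact config used in the PR:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1  # FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
mixed_precision: bf16
num_machines: 1
num_processes: 6  # one process per GPU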

@Satrat Satrat marked this pull request as ready for review March 13, 2024 20:46
@Satrat Satrat merged commit 965bdfa into main Mar 14, 2024
12 of 13 checks passed
@Satrat Satrat deleted the oom_debug branch March 14, 2024 03:05
Satrat added a commit that referenced this pull request Mar 14, 2024
* testing fix

* get rid of repeated log

* revert yaml