Skip to content

Commit

Permalink
Fix DeepSpeed Zero 3 Saving (axolotl-ai-cloud#709)
Browse files Browse the repository at this point in the history
* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
tokestermw and winglian committed Oct 19, 2023
1 parent 46e03af commit 4f14bc5
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
Expand Down Expand Up @@ -134,6 +135,22 @@ def terminate_handler(_, __, model):
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)

# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
Expand Down

0 comments on commit 4f14bc5

Please sign in to comment.