Commit

set fsdp state dict (axolotl-ai-cloud#584)
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
jphme and jphme committed Sep 15, 2023
1 parent beb7e76 commit 9c40dbc
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/axolotl/train.py
@@ -117,6 +117,10 @@ def terminate_handler(_, __, model):

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")

    if cfg.relora_steps:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
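
The guard added in this hunk can be read in isolation. The following is a minimal sketch, not the code from the commit: `use_full_state_dict_for_save`, `FakeFsdpPlugin`, and `make_fake_trainer` are hypothetical names introduced only for illustration, and the fake plugin merely records the requested type rather than configuring FSDP. It assumes, as the diff suggests, that the trainer exposes `is_fsdp_enabled` and `accelerator.state.fsdp_plugin.set_state_dict_type`. The likely intent is that the final save writes a complete, unsharded model rather than per-rank shards.

```python
from types import SimpleNamespace


def use_full_state_dict_for_save(trainer) -> bool:
    """Request FULL_STATE_DICT before the final save.

    Hypothetical helper mirroring the guard in the diff.
    Returns True if the state dict type was switched, False otherwise.
    """
    # Only touch the plugin when FSDP is actually in use.
    if not getattr(trainer, "is_fsdp_enabled", False):
        return False
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    return True


class FakeFsdpPlugin:
    """Stand-in for the real FSDP plugin (assumption: illustration only)."""

    def __init__(self):
        self.state_dict_type = "SHARDED_STATE_DICT"

    def set_state_dict_type(self, state_dict_type):
        # The real plugin configures FSDP; this fake only records the request.
        self.state_dict_type = state_dict_type


def make_fake_trainer(fsdp_enabled: bool):
    """Build a trainer-shaped object exposing only the attributes used above."""
    state = SimpleNamespace(fsdp_plugin=FakeFsdpPlugin())
    accelerator = SimpleNamespace(state=state)
    return SimpleNamespace(is_fsdp_enabled=fsdp_enabled, accelerator=accelerator)


if __name__ == "__main__":
    trainer = make_fake_trainer(fsdp_enabled=True)
    switched = use_full_state_dict_for_save(trainer)
    print(switched, trainer.accelerator.state.fsdp_plugin.state_dict_type)
```

With a non-FSDP trainer the helper leaves the plugin untouched, which matches the `if trainer.is_fsdp_enabled:` check in the hunk.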