From be75668400e116bfaac0e17d6363c009e103b8eb Mon Sep 17 00:00:00 2001 From: Jan Philipp Harries <2862336+jphme@users.noreply.github.com> Date: Fri, 15 Sep 2023 23:47:36 +0200 Subject: [PATCH] set fsdp state dict (#584) Co-authored-by: Jan Philipp Harries --- src/axolotl/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index fa6dbceaf..5ed5837f2 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -117,6 +117,10 @@ def terminate_handler(_, __, model): LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") + if trainer.is_fsdp_enabled: + trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") + LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") + if cfg.relora_steps: if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): model = model.merge_and_unload()