
FSDP/Accelerate: Training can't be continued from checkpoint with SHARDED_STATE_DICT #26186

Closed
jphme opened this issue Sep 15, 2023 · 6 comments

Comments

@jphme
Contributor

jphme commented Sep 15, 2023

System Info

  • transformers version: 4.34.0.dev0
  • Platform: Linux-5.4.0-156-generic-x86_64-with-glibc2.35
  • Python version: 3.9.18
  • Huggingface_hub version: 0.17.1
  • Safetensors version: 0.3.3
  • Accelerate version: 0.23.0.dev0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: 4 * V100
  • Using distributed or parallel set-up in script?: FSDP via accelerate

Who can help?

cc @pacman100

I can't continue training from checkpoints that were created with fsdp_state_dict_type: SHARDED_STATE_DICT via FSDP/Accelerate. The rest of the training works fine, as does saving the model after calling trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") once training has finished.
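For reference, the final save that does work is roughly this (minimal sketch; the output path is a placeholder):

# Minimal sketch of the part that works: after training, switch the FSDP plugin
# back to a full state dict so Trainer.save_model() writes a regular,
# consolidated checkpoint. The output path is a placeholder.
trainer.train()
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model("/workspace/models/fsdp_debug/final")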

This is the error when resuming:

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2093, in _load_from_checkpoint
    raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
ValueError: Can't find a valid checkpoint at /workspace/models/fsdp_debug/checkpoint-5

My FSDP config:

fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  limit_all_gathers: true

Checkpoint contents:

-rw-r--r-- 1 root root  635 Sep 15 12:00 config.json
-rw-r--r-- 1 root root  188 Sep 15 12:00 generation_config.json
drwxr-xr-x 2 root root 4.0K Sep 15 12:00 optimizer_0
drwxr-xr-x 2 root root 4.0K Sep 15 12:00 pytorch_model_0
-rw-r--r-- 1 root root  18K Sep 15 12:00 rng_state_0.pth
-rw-r--r-- 1 root root  18K Sep 15 12:00 rng_state_1.pth
-rw-r--r-- 1 root root  18K Sep 15 12:00 rng_state_2.pth
-rw-r--r-- 1 root root  18K Sep 15 12:00 rng_state_3.pth
-rw-r--r-- 1 root root  627 Sep 15 12:00 scheduler.pt
-rw-r--r-- 1 root root  946 Sep 15 12:00 trainer_state.json
-rw-r--r-- 1 root root 4.8K Sep 15 12:00 training_args.bin

#pytorch_model_0
total 13G
-rw-r--r-- 1 root root 3.2G Sep 15 12:00 __0_0.distcp
-rw-r--r-- 1 root root 3.2G Sep 15 12:00 __1_0.distcp
-rw-r--r-- 1 root root 3.2G Sep 15 12:00 __2_0.distcp
-rw-r--r-- 1 root root 3.2G Sep 15 12:00 __3_0.distcp

At first I thought this was just an error because the trainer expects a pytorch_model.bin, which isn't available in the directory.

However, when trying to call load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint) directly in _load_from_checkpoint, I get the following error:

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1694, in _inner_training_loop
    FullyShardedDataParallel.set_state_dict_type(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 608, in set_state_dict_type
    load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/accelerate/utils/fsdp_utils.py", line 129, in load_fsdp_model
    load_result = model.load_state_dict(state_dict)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/contextlib.py", line 126, in __exit__
        self._load_from_checkpoint(resume_from_checkpoint, model)state_dict_config_type = _state_dict_type_to_config[state_dict_type]

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2076, in _load_from_checkpoint
next(self.gen)
KeyError  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 720, in state_dict_type
: None
    FullyShardedDataParallel.set_state_dict_type(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 608, in set_state_dict_type
    state_dict_config_type = _state_dict_type_to_config[state_dict_type]
KeyError: None

Content of self.accelerator.state.fsdp_plugin:

FullyShardedDataParallelPlugin(sharding_strategy=<ShardingStrategy.FULL_SHARD: 1>, backward_prefetch=None, mixed_precision_policy=MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True), auto_wrap_policy=None, cpu_offload=CPUOffload(offload_params=False), ignored_modules=None, state_dict_type=<StateDictType.SHARDED_STATE_DICT: 3>, state_dict_config=None, optim_state_dict_config=None, limit_all_gathers=True, use_orig_params=False, param_init_fn=<function FullyShardedDataParallelPlugin.__post_init__.<locals>.<lambda> at 0x7f66aabbb160>, sync_module_states=True, forward_prefetch=False, activation_checkpointing=False)

Any idea on how to fix this? Many thanks!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

See above. A rough, hypothetical sketch of the setup follows (placeholder model and dummy data; launched twice with accelerate launch and the FSDP config above, the second run resuming from checkpoint-5).
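# Hypothetical minimal sketch, NOT my exact script: placeholder model, dummy data,
# and the output directory from the report above. Launched with `accelerate launch`
# and the FSDP config shown earlier; the second run (resuming from checkpoint-5)
# fails with "Can't find a valid checkpoint".
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_name = "EleutherAI/pythia-70m"  # placeholder; the real run uses a larger model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Tiny dummy dataset, just enough for a handful of optimizer steps.
dataset = Dataset.from_dict({"text": ["hello world"] * 64}).map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=32),
    batched=True,
)

args = TrainingArguments(
    output_dir="/workspace/models/fsdp_debug",
    per_device_train_batch_size=2,
    max_steps=10,
    save_steps=5,   # writes checkpoint-5 with sharded .distcp files
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# First run: trainer.train() trains and saves checkpoint-5.
# Second run: resuming fails with "Can't find a valid checkpoint at ...".
trainer.train(resume_from_checkpoint="/workspace/models/fsdp_debug/checkpoint-5")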

Expected behavior

Training can be resumed from checkpoints.

@muellerzr
Contributor

cc @pacman100

@ArthurZucker
Collaborator

I believe this will be fixed by #26180, will review.

@jphme
Contributor Author

jphme commented Sep 17, 2023

I believe this will be fixed by #26180, will review.

Many thanks, very timely, and it does indeed solve the issue! I commented on the PR with a follow-up issue, but will close this one as the specific problem is solved by the PR.

@jphme jphme closed this as completed Sep 17, 2023
@jerryjalapeno

I am facing this exact issue. What is the script that will consolidate the FSDP model shards into a single file? I have the checkpoint but no way to save the model.

@jphme
Contributor Author

jphme commented Sep 19, 2023

I am facing this exact issue. What is the script that will consolidate the FSDP model shards into a single file? I have the checkpoint but no way to save the model.

Try out #26180 (there @pacman100 also linked to the torch methods to directly load sharded state dicts).
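If you only need a one-off consolidation on a single machine, something along these lines should also work; this is a rough, untested sketch that assumes accelerate stored the weights under a "model" key inside pytorch_model_0/ and that the base model matches the checkpoint:

# Rough, untested sketch: consolidate the sharded .distcp files into a single
# model on one process (no_dist=True avoids needing an initialized process group).
# The model name, checkpoint path, and output directory are placeholders.
import torch.distributed.checkpoint as dist_cp
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("your-base-model")  # placeholder
state_dict = {"model": model.state_dict()}  # assumed key used by accelerate's sharded save

dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("checkpoint-5/pytorch_model_0"),
    no_dist=True,
)
model.load_state_dict(state_dict["model"])
model.save_pretrained("consolidated")  # placeholder output directory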

Unfortunately, as it currently stands, you can start training, create checkpoints, finish training, and save the model, but you still run OOM when trying to continue from a checkpoint. So if you max out VRAM during your training runs, checkpoints are currently useless with SHARDED_STATE_DICT :/.

@seilk

seilk commented Feb 1, 2024

I am facing this exact issue. What is the script that will consolidate the FSDP model shards into a single file? I have the checkpoint but no way to save the model.

Try out #26180 (there @pacman100 also linked to the torch methods to directly load sharded state dicts).

Unfortunately, as it currently stands, you can start training, create checkpoints, finish training, and save the model, but you still run OOM when trying to continue from a checkpoint. So if you max out VRAM during your training runs, checkpoints are currently useless with SHARDED_STATE_DICT :/.

@jphme
Does your statement mean that if a model is trained using FSDP, it cannot be restarted from a saved checkpoint in the middle of training, and must be retrained from iteration 0?
