add warning when using gradient_checkpointing with FSDP full shard #31578

Merged
merged 8 commits on Jul 9, 2024
src/transformers/training_args.py (11 changes: 10 additions & 1 deletion)
@@ -1820,7 +1820,7 @@ def __post_init__(self):
             raise ValueError("warmup_steps must be either 0 or > 1")
 
         if isinstance(self.fsdp, bool):
-            self.fsdp = "full_shard" if self.fsdp else ""
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:
@@ -1831,6 +1831,15 @@ def __post_init__(self):
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
 
+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+                " operation in the backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
+
         if self.fsdp_config is None:
             self.fsdp_config = {}
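For reference, a minimal sketch of the setup the new warning points to: activation checkpointing is enabled through `fsdp_config` rather than via `gradient_checkpointing=True`, which would now trigger the warning when combined with FSDP full shard. The `output_dir` value is a placeholder, and running with FSDP still assumes the usual multi-GPU launch (e.g. via `accelerate launch` or `torchrun`).

from transformers import TrainingArguments

# Sketch of the recommended configuration: checkpoint activations through
# fsdp_config so FSDP can avoid the redundant AllGather in the backward pass.
args = TrainingArguments(
    output_dir="out",  # placeholder output directory
    fsdp="full_shard",  # parsed into [FSDPOption.FULL_SHARD]
    fsdp_config={"activation_checkpointing": True},
    # Note: do NOT also set gradient_checkpointing=True here; that is the
    # combination the new check in __post_init__ warns about.
)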