diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 6d68405ab35a24..bc83b131a2e677 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1820,7 +1820,7 @@ def __post_init__(self):
             raise ValueError("warmup_steps must be either 0 or > 1")

         if isinstance(self.fsdp, bool):
-            self.fsdp = "full_shard" if self.fsdp else ""
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:
@@ -1831,6 +1831,15 @@ def __post_init__(self):
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")

+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+                " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
+
         if self.fsdp_config is None:
             self.fsdp_config = {}
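
For reference, a minimal sketch of the configuration the new warning points to. The `output_dir` value and the particular FSDP options are illustrative and not taken from the patch; only `fsdp`, `fsdp_config`, and the `activation_checkpointing` key are assumed from the existing TrainingArguments API.

from transformers import TrainingArguments

# Enable activation checkpointing via fsdp_config instead of the top-level
# gradient_checkpointing flag, which triggers the warning added above when
# combined with full_shard or hybrid_shard.
training_args = TrainingArguments(
    output_dir="out",                                # illustrative path
    fsdp="full_shard auto_wrap",                     # example FSDP options
    fsdp_config={"activation_checkpointing": True},  # preferred over gradient_checkpointing
    # gradient_checkpointing=True,  # avoid: redundant AllGather in the backward pass per the warning
)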