From c3f3a8417506c9ed53b8bf27b6955eeb57b47adf Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Mon, 24 Jun 2024 17:14:10 +0000
Subject: [PATCH 1/6] add warning when using with FSDP full shard

---
 src/transformers/training_args.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 41a9607e312105..03665a414806df 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1783,7 +1783,7 @@ def __post_init__(self):
             raise ValueError("warmup_steps must be either 0 or > 1")
 
         if isinstance(self.fsdp, bool):
-            self.fsdp = "full_shard" if self.fsdp else ""
+            self.fsdp = FSDPOption.FULL_SHARD if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:
@@ -1793,6 +1793,13 @@ def __post_init__(self):
             )
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+        if FSDPOption.FULL_SHARD in self.fsdp and self.gradient_checkpointing:
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former one introduces a redundant AllGather"
+                " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
 
         if self.fsdp_config is None:
             self.fsdp_config = {}

From f1e9321ce601b6572ee804dd494fb9fedfaaa098 Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Mon, 24 Jun 2024 11:19:00 -0700
Subject: [PATCH 2/6] fix style

---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 03665a414806df..088549c9dc2048 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1793,7 +1793,7 @@ def __post_init__(self):
             )
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
-
+
         if FSDPOption.FULL_SHARD in self.fsdp and self.gradient_checkpointing:
             logger.warning(
                 "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"

From ea734294cf7d02fe72fef1707d8acd55939f79e1 Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Wed, 26 Jun 2024 10:12:37 -0700
Subject: [PATCH 3/6] Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 088549c9dc2048..26729501d9a355 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1797,7 +1797,7 @@ def __post_init__(self):
         if FSDPOption.FULL_SHARD in self.fsdp and self.gradient_checkpointing:
             logger.warning(
                 "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
-                " use `activation_checkpointing` in `fsdp_config`. The former one introduces a redundant AllGather"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
                 " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
             )
 

From 8bb8d4b7e06cda5ebc41800a7a20a59337887ca9 Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Wed, 26 Jun 2024 10:12:44 -0700
Subject: [PATCH 4/6] Update src/transformers/training_args.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 26729501d9a355..085f94f47530a4 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1783,7 +1783,7 @@ def __post_init__(self):
             raise ValueError("warmup_steps must be either 0 or > 1")
 
         if isinstance(self.fsdp, bool):
-            self.fsdp = FSDPOption.FULL_SHARD if self.fsdp else ""
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
         if isinstance(self.fsdp, str):
             self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
         if self.fsdp == [FSDPOption.OFFLOAD]:

From 07b1b828a092fa3654b875de510000b408208b3c Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Wed, 26 Jun 2024 10:15:15 -0700
Subject: [PATCH 5/6] add hybrid shard warn

---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 085f94f47530a4..861bc499ffab0a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1794,7 +1794,7 @@ def __post_init__(self):
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
 
-        if FSDPOption.FULL_SHARD in self.fsdp and self.gradient_checkpointing:
+        if self.gradient_checkpointing and (FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp):
             logger.warning(
                 "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
                 " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"

From a1fd23eb98034e0bbf2fe5fe844e444db2d011cf Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Tue, 9 Jul 2024 14:43:17 -0700
Subject: [PATCH 6/6] fix style

---
 src/transformers/training_args.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 553c7df709acec..bc83b131a2e677 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1831,7 +1831,9 @@ def __post_init__(self):
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
 
-        if self.gradient_checkpointing and (FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp):
+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
             logger.warning(
                 "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
                 " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
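Usage note (a minimal sketch, not taken from the patches above): the warning added in this series points users away from `gradient_checkpointing=True` in TrainingArguments and toward the `activation_checkpointing` key of `fsdp_config` when FSDP full or hybrid sharding is enabled. Assuming the standard `TrainingArguments` API, with "out" as a placeholder output directory, the recommended setup looks roughly like this:

from transformers import TrainingArguments

# This combination now triggers the warning: gradient checkpointing requested
# through TrainingArguments while FSDP full sharding is active.
# args = TrainingArguments(output_dir="out", fsdp="full_shard", gradient_checkpointing=True)

# Alternative suggested by the warning text: request activation checkpointing
# through fsdp_config so FSDP applies it itself, avoiding the redundant
# backward-pass AllGather described in issue #30404.
args = TrainingArguments(
    output_dir="out",  # placeholder path
    fsdp="full_shard",  # parsed in __post_init__ into [FSDPOption.FULL_SHARD]
    fsdp_config={"activation_checkpointing": True},
)

The same sketch applies with `fsdp="hybrid_shard"`, which patch 5 adds to the warning condition.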