diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index a30891bddbc10e..5f92ea0bd5a2b3 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -782,6 +782,7 @@ def forward(
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        **kwargs,
     ):
         residual = hidden_states

diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index c0a8c193d4cc5c..ed6d61793bcc6f 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -628,6 +628,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -642,6 +643,11 @@ def forward(
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 22735a7a0a38e0..2b9707c5d5f8cc 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -706,6 +706,7 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        **kwargs,
     ) -> Union[
         Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
     ]:
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 7f3c3070330403..8cbe8fe35926b8 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -301,6 +301,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -590,6 +591,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -687,6 +689,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -701,6 +704,11 @@ def forward(
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 73738dade42b7e..07942e87e68c29 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -591,6 +591,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -689,6 +690,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -703,8 +705,12 @@ def forward(
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
-
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 019f69d0eee2d7..fd46513ff84b4f 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -888,6 +888,7 @@ def forward(
         output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -906,6 +907,9 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 0458f916d375ff..235d4f1c38a499 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -666,6 +666,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -680,6 +681,11 @@ def forward(
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 0c68de968d1610..7f8483d88884b1 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -771,6 +771,7 @@ def forward(
         use_cache: Optional[bool] = False,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -790,6 +791,9 @@ def forward(
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 05becfe3a436ae..24b7e878c300d1 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -830,6 +830,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -847,6 +848,11 @@ def forward(
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index fceee705fdc7e2..51f7dba7d91127 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -734,6 +734,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -749,6 +750,9 @@ def forward(
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 951b1c35815d8d..6a6e21cacf60c0 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -876,6 +876,7 @@ def forward(
         output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -894,6 +895,9 @@ def forward(
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index c02a90f6a582af..f2b6c8a223fa0d 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -713,6 +713,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -728,6 +729,9 @@ def forward(
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence.
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
+                into the model
         """
         residual = hidden_states

diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 9ae55ecdec2dcd..ff5bd851069738 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -14,6 +14,7 @@
 import itertools
 import os
+import subprocess
 import unittest
 from copy import deepcopy
 from functools import partial
@@ -31,6 +32,7 @@
     require_accelerate,
     require_fsdp,
     require_torch_accelerator,
+    require_torch_gpu,
     require_torch_multi_accelerator,
     slow,
     torch_device,
@@ -276,6 +278,20 @@ def test_training_and_can_resume_normally(self, state_dict_type):
                 if "learning_rate" in log:
                     self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)

+    @require_torch_multi_accelerator
+    @slow
+    @require_torch_gpu
+    @require_fsdp
+    def test_fsdp_cpu_offloading(self):
+        try:
+            subprocess.run(
+                "accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
+                shell=True,
+                check=True,
+            )
+        except subprocess.CalledProcessError as e:
+            raise AssertionError("CPU offloading failed with FSDP!") from e
+
     def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
         if not use_accelerate:
             fsdp_args = [
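
For context, a minimal sketch (not part of the diff) of the failure mode the repeated `**kwargs` additions address: wrappers such as FSDP's CPU-offloading and activation-checkpointing utilities may call the wrapped layer's forward with extra keyword arguments, and a rigid signature raises a TypeError, while a trailing `**kwargs` absorbs and ignores them. The `offload_hook` kwarg below is hypothetical, standing in for whatever a wrapper injects.

import torch
from torch import nn


class RigidLayer(nn.Module):
    """A decoder-layer stand-in whose forward rejects unexpected kwargs."""

    def forward(self, hidden_states: torch.Tensor, use_cache: bool = False):
        return hidden_states


class TolerantLayer(nn.Module):
    """The same layer after the patch: a trailing **kwargs absorbs injected arguments."""

    def forward(self, hidden_states: torch.Tensor, use_cache: bool = False, **kwargs):
        # Injected kwargs land here and are deliberately ignored.
        return hidden_states


x = torch.ones(1, 4)

try:
    # Before the patch: the injected kwarg makes the call blow up.
    RigidLayer()(x, use_cache=False, offload_hook=object())
except TypeError as err:
    print(f"rigid layer fails: {err}")

# After the patch: the extra kwarg is silently swallowed and the layer runs.
out = TolerantLayer()(x, use_cache=False, offload_hook=object())
print(f"tolerant layer output shape: {tuple(out.shape)}")

Run under a stock PyTorch install, the rigid layer prints a TypeError while the tolerant one returns a (1, 4) tensor; the new test_fsdp_cpu_offloading then checks the same tolerance end to end by launching the offloading script through accelerate.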