[core / modeling] Fix training bug with PEFT + GC (#28031)
fix training bug
younesbelkada authored and amyeroberts committed Dec 18, 2023
1 parent c48787f commit d1dec79
Showing 5 changed files with 35 additions and 35 deletions.
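The change is identical in every touched model: the gradient-checkpointing guard that forces `use_cache = False` is moved from just before the decoder-layer loop to the top of `forward`, before `past_key_values` is inspected or converted into a cache object. Previously, when training a PEFT-wrapped model with gradient checkpointing enabled, the cache bookkeeping still ran while `use_cache` was `True`, because the flag was only flipped off afterwards. Below is a minimal, illustrative sketch of the reordered forward pass; `ToyDecoderModel`, its layer sizes, and its signature are stand-ins and not the actual transformers code, while `use_cache`, `past_key_values`, `self.gradient_checkpointing`, `self.training`, and `logger.warning_once` are the names used in the diffs.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from transformers.utils import logging

logger = logging.get_logger(__name__)


class ToyDecoderModel(nn.Module):
    """Illustrative stand-in for the patched model classes (not transformers code)."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = True
        self.layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])

    def forward(self, inputs_embeds, past_key_values=None, use_cache=True):
        # The guard now runs first, before any cache bookkeeping.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Cache handling only runs if use_cache survived the guard, so no cache
        # is inspected or built on the gradient-checkpointed training path.
        past_key_values_length = 0
        if use_cache and past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        hidden_states = inputs_embeds
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states, past_key_values_length


# Example: model = ToyDecoderModel().train()
#          out, _ = model(torch.randn(1, 4, 16))  # warns once, runs with use_cache=False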
@@ -578,6 +578,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
@@ -608,13 +615,6 @@ def forward(
 
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
14 changes: 7 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
@@ -1000,6 +1000,13 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         past_key_values_length = 0
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
@@ -1038,13 +1045,6 @@ def forward(
         # embed positions
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
14 changes: 7 additions & 7 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -855,6 +855,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         past_key_values_length = 0
 
         if use_cache:
@@ -899,13 +906,6 @@ def forward(
 
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
14 changes: 7 additions & 7 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -608,6 +608,13 @@ def forward(
         seq_length_with_past = seq_length
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
@@ -635,13 +642,6 @@ def forward(
 
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
14 changes: 7 additions & 7 deletions src/transformers/models/phi/modeling_phi.py
@@ -860,6 +860,13 @@ def forward(
 
         past_key_values_length = 0
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
@@ -890,13 +897,6 @@ def forward(
 
         hidden_states = inputs_embeds
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
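For context, the hunks above all sit on the forward pass exercised by PEFT fine-tuning with gradient checkpointing. A minimal sketch of that setup follows; the tiny checkpoint name, LoRA hyperparameters, and target modules are illustrative and not part of this commit.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Enable gradient checkpointing, then wrap with LoRA adapters and switch to training mode.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.train()

# One training-style forward/backward pass. With the guard moved to the top of
# forward(), use_cache is switched off before any past_key_values handling runs.
batch = tokenizer("gradient checkpointing with PEFT", return_tensors="pt")
outputs = model(**batch, labels=batch["input_ids"])
outputs.loss.backward()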
