From db8cc744cba94f49528dd648b59d8d759ca36998 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 29 Mar 2024 18:29:55 +0000
Subject: [PATCH] [test_all] fix models with custom generate

---
 .../models/musicgen/modeling_musicgen.py |  5 +--
 .../modeling_musicgen_melody.py          | 36 ++++++++++---------
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index ca953b5a304e0c..6cd9af5c292cd9 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -2692,10 +2692,7 @@ def generate(
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor,
-                model_kwargs,
-                model_input_name,
-                generation_config
+                inputs_tensor, model_kwargs, model_input_name, generation_config
             )
 
         if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 0840635f6535b2..a374bd454e1608 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1563,8 +1563,6 @@ def generate(
         batch_size = input_ids.shape[0] // self.num_codebooks
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -1662,8 +1660,10 @@ def generate(
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
@@ -1688,8 +1688,10 @@ def generate(
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
@@ -2303,12 +2305,13 @@ def _prepare_encoder_hidden_states_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
-        model_input_name: Optional[str] = None,
-        guidance_scale: Optional[float] = None,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
     ) -> Dict[str, Any]:
         encoder_hidden_states = None
         # attention mask is consumed once to produce text conditional hidden states through the text encoder
         encoder_attention_mask = model_kwargs.pop("attention_mask")
+        guidance_scale = generation_config.guidance_scale
 
         # 1. condition on text
         if inputs_tensor is not None:
@@ -2331,6 +2334,8 @@ def _prepare_encoder_hidden_states_kwargs_for_generation(
                 encoder_kwargs = {
                     argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
                 }
+            encoder_kwargs["output_attentions"] = generation_config.output_attentions
+            encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
 
             # make sure that encoder returns `ModelOutput`
             model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
@@ -2541,8 +2546,6 @@ def generate(
         batch_size = inputs_tensor.shape[0]
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -2554,10 +2557,7 @@ def generate(
         if "encoder_hidden_states" not in model_kwargs:
             # encoder_hidden_states are created and added to `model_kwargs`
             model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
-                inputs_tensor,
-                model_kwargs,
-                model_input_name,
-                guidance_scale=generation_config.guidance_scale,
+                inputs_tensor, model_kwargs, model_input_name, generation_config
             )
 
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
@@ -2653,13 +2653,15 @@ def generate(
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
@@ -2679,14 +2681,16 @@ def generate(
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
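Usage sketch (not part of the patch): the diff above stops passing
`output_attentions`/`output_hidden_states` through `model_kwargs` and instead forwards
them, together with the newer `output_logits` flag, directly to the text encoder and to
`_greedy_search`/`_sample`. The snippet below is a minimal, hedged way to exercise that
path; it assumes the public "facebook/musicgen-melody" checkpoint and a transformers
release containing this change, and the comments about the returned fields are
expectations, not guarantees.

    import torch
    from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

    processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
    model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")

    inputs = processor(text=["80s pop with a driving synth line"], padding=True, return_tensors="pt")

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=64,
            output_attentions=True,    # now reaches both the text encoder and the decoding loop
            output_hidden_states=True,
            output_logits=True,        # flag newly threaded through by this patch
            return_dict_in_generate=True,
        )

    # With return_dict_in_generate=True the model should return a ModelOutput whose
    # `sequences` holds the generated audio and whose score/attention fields are
    # populated because of the flags requested above.
    print(type(out), out.sequences.shape)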