[test_all] fix models with custom generate
gante committed Apr 23, 2024
1 parent 7779082 commit db8cc74
Showing 2 changed files with 21 additions and 20 deletions.

src/transformers/models/musicgen/modeling_musicgen.py (5 changes: 1 addition & 4 deletions)

@@ -2692,10 +2692,7 @@ def generate(
if "encoder_outputs" not in model_kwargs:
# encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
inputs_tensor,
model_kwargs,
model_input_name,
generation_config
inputs_tensor, model_kwargs, model_input_name, generation_config
)

if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
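
For context, here is a minimal usage sketch (not part of the commit) of the MusicGen generate() path that the hunk above touches; the checkpoint name and flag choices are illustrative only.

    from transformers import AutoProcessor, MusicgenForConditionalGeneration

    # Illustrative only: the output_* flags passed to generate() are absorbed into the
    # GenerationConfig that the hunk above forwards to
    # _prepare_text_encoder_kwargs_for_generation.
    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

    inputs = processor(text=["80s pop track with bassy drums"], padding=True, return_tensors="pt")
    out = model.generate(
        **inputs,
        max_new_tokens=64,
        output_hidden_states=True,
        return_dict_in_generate=True,  # return a ModelOutput rather than a plain audio tensor
    )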

src/transformers/models/musicgen_melody/modeling_musicgen_melody.py (36 changes: 20 additions & 16 deletions)

@@ -1563,8 +1563,6 @@ def generate(
         batch_size = input_ids.shape[0] // self.num_codebooks

         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale

@@ -1662,8 +1660,10 @@ def generate(
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,

@@ -1688,8 +1688,10 @@ def generate(
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,

@@ -2303,12 +2305,13 @@ def _prepare_encoder_hidden_states_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
-        model_input_name: Optional[str] = None,
-        guidance_scale: Optional[float] = None,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
     ) -> Dict[str, Any]:
         encoder_hidden_states = None
         # attention mask is consumed once to produce text conditional hidden states through the text encoder
         encoder_attention_mask = model_kwargs.pop("attention_mask")
+        guidance_scale = generation_config.guidance_scale

         # 1. condition on text
         if inputs_tensor is not None:

@@ -2331,6 +2334,8 @@ def _prepare_encoder_hidden_states_kwargs_for_generation(
             encoder_kwargs = {
                 argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
             }
+            encoder_kwargs["output_attentions"] = generation_config.output_attentions
+            encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states

             # make sure that encoder returns `ModelOutput`
             model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name

@@ -2541,8 +2546,6 @@ def generate(
         batch_size = inputs_tensor.shape[0]

         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale

@@ -2554,10 +2557,7 @@ def generate(
if "encoder_hidden_states" not in model_kwargs:
# encoder_hidden_states are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_hidden_states_kwargs_for_generation(
inputs_tensor,
model_kwargs,
model_input_name,
guidance_scale=generation_config.guidance_scale,
inputs_tensor, model_kwargs, model_input_name, generation_config
)

# 5. Prepare `input_ids` which will be used for auto-regressive generation

@@ -2653,13 +2653,15 @@ def generate(
             )

             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,

@@ -2679,14 +2681,16 @@ def generate(
             )

             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
                 output_scores=generation_config.output_scores,
+                output_logits=generation_config.output_logits,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
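
Taken together, the musicgen_melody hunks follow one pattern: filter model_kwargs against the encoder's forward() signature, then set the output flags from the GenerationConfig instead of carrying them in model_kwargs. A minimal sketch of that pattern follows; build_encoder_kwargs is a hypothetical helper name, not a transformers API.

    import inspect

    from transformers import GenerationConfig

    def build_encoder_kwargs(encoder_forward, model_kwargs, generation_config: GenerationConfig):
        # Keep only the arguments the encoder's forward() actually accepts ...
        signature = set(inspect.signature(encoder_forward).parameters)
        encoder_kwargs = {k: v for k, v in model_kwargs.items() if k in signature}
        # ... then set the output flags explicitly from the generation config.
        encoder_kwargs["output_attentions"] = generation_config.output_attentions
        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        return encoder_kwargs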
