diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c65511db16854d..ebe968b7ac4ec9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -754,12 +754,12 @@ def _get_logits_warper( warpers = LogitsProcessorList() # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a - # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1) + # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1) if generation_config.num_beams > 1: - if isinstance(generation_config.eos_token_id, list): - min_tokens_to_keep = len(generation_config.eos_token_id) + 1 - elif isinstance(generation_config.eos_token_id, torch.Tensor): - min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1 + if isinstance(generation_config._eos_token_tensor, list): + min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1 + elif isinstance(generation_config._eos_token_tensor, torch.Tensor): + min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1 else: min_tokens_to_keep = 2 else: @@ -863,31 +863,31 @@ def _get_logits_processor( processors.append( NoBadWordsLogitsProcessor( generation_config.bad_words_ids, - generation_config.eos_token_id, + generation_config._eos_token_tensor, ) ) if ( generation_config.min_length is not None - and generation_config.eos_token_id is not None + and generation_config._eos_token_tensor is not None and generation_config.min_length > 0 ): processors.append( MinLengthLogitsProcessor( generation_config.min_length, - generation_config.eos_token_id, + generation_config._eos_token_tensor, device=device, ) ) if ( generation_config.min_new_tokens is not None - and generation_config.eos_token_id is not None + and generation_config._eos_token_tensor is not None and generation_config.min_new_tokens > 0 ): processors.append( MinNewTokensLengthLogitsProcessor( input_ids_seq_length, generation_config.min_new_tokens, - generation_config.eos_token_id, + generation_config._eos_token_tensor, device=device, ) ) @@ -918,7 +918,7 @@ def _get_logits_processor( processors.append( ExponentialDecayLengthPenalty( generation_config.exponential_decay_length_penalty, - generation_config.eos_token_id, + generation_config._eos_token_tensor, input_ids_seq_length, ) ) @@ -997,8 +997,8 @@ def _get_stopping_criteria( "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." ) criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) - if generation_config.eos_token_id is not None: - criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) + if generation_config._eos_token_tensor is not None: + criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria @@ -1349,13 +1349,15 @@ def _prepare_generation_config( self, generation_config: Optional[GenerationConfig], **kwargs: Dict ) -> Tuple[GenerationConfig, Dict]: """ - Prepares the base generation config, then applies any generation configuration options from kwargs. + Prepares the base generation config, then applies any generation configuration options from kwargs. This + function handles retrocompatibility with respect to configuration files. 
""" # TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400) # replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with # the parameterization in `fullgraph=False` so as to enable `fullgraph=True`. # priority: `generation_config` argument > `model.generation_config` (the default generation config) + using_model_generation_config = False if generation_config is None: # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior, # three conditions must be met @@ -1378,6 +1380,7 @@ def _prepare_generation_config( " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )" ) self.generation_config = new_generation_config + using_model_generation_config = True generation_config = self.generation_config # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` @@ -1395,6 +1398,16 @@ def _prepare_generation_config( else: generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) + # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model + if not using_model_generation_config: + if generation_config.bos_token_id is None: + generation_config.bos_token_id = self.generation_config.bos_token_id + if generation_config.eos_token_id is None: + generation_config.eos_token_id = self.generation_config.eos_token_id + if generation_config.pad_token_id is None: + generation_config.pad_token_id = self.generation_config.pad_token_id + if generation_config.decoder_start_token_id is None: + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id return generation_config, model_kwargs @@ -1493,52 +1506,46 @@ def _prepare_special_tokens( function). However, if called outside `generate`, consider creating a copy of `generation_config` first. 
""" - # Convert special tokens to tensors (if they exist either in kwargs or in self.config) - def _tensor_or_none(token_kwargs, token_self, device=None): - if device is None: - device = self.device - - token = token_kwargs if token_kwargs is not None else token_self + # Convert special tokens to tensors + def _tensor_or_none(token, device=None): if token is None: return token - elif isinstance(token, torch.Tensor): - return token.to(device) + device = device if device is not None else self.device + if isinstance(token, torch.Tensor): + return token.to(device) return torch.tensor(token, device=device, dtype=torch.long) - bos_token_id = _tensor_or_none( - generation_config.bos_token_id, self.generation_config.bos_token_id, device=device - ) - eos_token_id = _tensor_or_none( - generation_config.eos_token_id, self.generation_config.eos_token_id, device=device - ) - pad_token_id = _tensor_or_none( - generation_config.pad_token_id, self.generation_config.pad_token_id, device=device - ) - decoder_start_token_id = _tensor_or_none( - generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device - ) + bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device) + eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device) + pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device) + decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device) # for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892) if self.config.is_encoder_decoder: - decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id + decoder_start_token_tensor = ( + decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor + ) # We can have more than one eos token. Always treat it as a 1D tensor (when it exists). - if eos_token_id is not None and eos_token_id.ndim == 0: - eos_token_id = eos_token_id.unsqueeze(0) + if eos_token_tensor is not None and eos_token_tensor.ndim == 0: + eos_token_tensor = eos_token_tensor.unsqueeze(0) # Set pad token if unset (and there are conditions to do so) - if pad_token_id is None and eos_token_id is not None: + if pad_token_tensor is None and eos_token_tensor is not None: if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: logger.warning( "The attention mask and the pad token id were not set. As a consequence, you may observe " "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." ) - pad_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.") + pad_token_tensor = eos_token_tensor[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") # we can't infer attn mask if pad token is set to be eos token in model's generation config - if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any(): + if ( + eos_token_tensor is not None + and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any() + ): if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: logger.warning_once( "The attention mask is not set and cannot be inferred from input because pad token is same as eos token." 
@@ -1547,21 +1554,26 @@ def _tensor_or_none(token_kwargs, token_self, device=None): ) # Sanity checks/warnings - if self.config.is_encoder_decoder and decoder_start_token_id is None: + if self.config.is_encoder_decoder and decoder_start_token_tensor is None: raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) - if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()): + if eos_token_tensor is not None and ( + torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any() + ): logger.warning( - f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not " + f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not " "stop until the maximum length is reached. Depending on other flags, it may even crash." ) # Update generation config with the updated special tokens tensors - generation_config.bos_token_id = bos_token_id - generation_config.eos_token_id = eos_token_id - generation_config.pad_token_id = pad_token_id - generation_config.decoder_start_token_id = decoder_start_token_id + # NOTE: this must be written into a different attribute name than the one holding the original special tokens + # (in their non-tensor form), in order to enable end-to-end compilation. See + # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations + generation_config._bos_token_tensor = bos_token_tensor + generation_config._eos_token_tensor = eos_token_tensor + generation_config._pad_token_tensor = pad_token_tensor + generation_config._decoder_start_token_tensor = decoder_start_token_tensor @torch.no_grad() def generate( @@ -1696,10 +1708,10 @@ def generate( # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. if ( - generation_config.pad_token_id is not None + generation_config._pad_token_tensor is not None and batch_size > 1 and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0 ): logger.warning( "A decoder-only architecture is being used, but right-padding was detected! 
For correct " @@ -1716,7 +1728,7 @@ def generate( if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: @@ -1731,7 +1743,7 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, device=inputs_tensor.device, ) else: @@ -2279,7 +2291,7 @@ def _dola_decoding( raise ValueError("DoLa decoding is only available for decoder-only models.") # init values - pad_token_id = generation_config.pad_token_id + pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2486,7 +2498,7 @@ def _contrastive_search( has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) top_k = generation_config.top_k penalty_alpha = generation_config.penalty_alpha - pad_token_id = generation_config.pad_token_id + pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -2877,7 +2889,7 @@ def _sample( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id + pad_token_id = generation_config._pad_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3084,8 +3096,8 @@ def _beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3366,8 +3378,8 @@ def _group_beam_search( `model.config.is_encoder_decoder=True`. """ # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores @@ -3658,8 +3670,8 @@ def _constrained_beam_search( `model.config.is_encoder_decoder=True`. 
""" # init values - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id + pad_token_id = generation_config._pad_token_tensor + eos_token_id = generation_config._eos_token_tensor output_attentions = generation_config.output_attentions output_hidden_states = generation_config.output_hidden_states output_scores = generation_config.output_scores diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 0102d1c267c7ad..7aaaeb461c1343 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1539,75 +1539,43 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` + # 3. Define model inputs` input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = input_ids.shape[0] // self.num_codebooks - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale - requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, generation_config.pad_token_id, generation_config.eos_token_id + input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) # 5. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: - logger.warning( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation." 
- ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - logger.warning( - f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=input_ids, + input_ids_length=input_ids_length, + ) # 6. Prepare `input_ids` which will be used for auto-regressive generation # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, - pad_token_id=generation_config.decoder_start_token_id, + pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) @@ -1628,7 +1596,7 @@ def generate( # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, + input_ids_seq_length=input_ids_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, @@ -1682,7 +1650,7 @@ def generate( output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.num_codebooks, -1 ) @@ -2590,39 +2558,23 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
- ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale - requires_attention_mask = "encoder_outputs" not in model_kwargs - if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) if "encoder_outputs" not in model_kwargs: @@ -2642,45 +2594,28 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + bos_token_id=generation_config._bos_token_tensor, device=inputs_tensor.device, ) # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - logger.warning( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation." - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. 
" - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - logger.warning( - f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, - pad_token_id=generation_config.decoder_start_token_id, + pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) # stash the delay mask so that we don't have to recompute in each forward pass @@ -2701,7 +2636,7 @@ def generate( # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, + input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, @@ -2756,7 +2691,7 @@ def generate( output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.decoder.num_codebooks, -1 ) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 3140b9f286448f..eafb7baad8f740 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1375,6 +1375,7 @@ def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask): return input_ids @torch.no_grad() + # Ignore copy def generate( self, inputs: Optional[torch.Tensor] = None, @@ -1460,75 +1461,43 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." 
- ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None - # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` + # 3. Define model inputs` input_ids, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = input_ids.shape[0] // self.num_codebooks - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale - # Ignore copy - if model_kwargs.get("attention_mask", None) is None: + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - input_ids, generation_config.pad_token_id, generation_config.eos_token_id + input_ids, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) # 5. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] + input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: - logger.warning( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation." - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - logger.warning( - f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." 
- ) + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=input_ids, + input_ids_length=input_ids_length, + ) # 6. Prepare `input_ids` which will be used for auto-regressive generation - # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) + # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen) input_ids, delay_pattern_mask = self.build_delay_pattern_mask( input_ids, - pad_token_id=generation_config.decoder_start_token_id, + pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) @@ -1549,7 +1518,7 @@ def generate( # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, + input_ids_seq_length=input_ids_length, encoder_input_ids=input_ids, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, @@ -1603,7 +1572,7 @@ def generate( output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.num_codebooks, -1 ) @@ -2397,7 +2366,7 @@ def generate( Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. - synced_gpus (`bool`, *optional*): + synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed @@ -2414,18 +2383,14 @@ def generate( If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - - [`~generation.GreedySearchDecoderOnlyOutput`], - - [`~generation.SampleDecoderOnlyOutput`], - - [`~generation.BeamSearchDecoderOnlyOutput`], - - [`~generation.BeamSampleDecoderOnlyOutput`] + - [`~generation.GenerateDecoderOnlyOutput`], + - [`~generation.GenerateBeamDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - - [`~generation.GreedySearchEncoderDecoderOutput`], - - [`~generation.SampleEncoderDecoderOutput`], - - [`~generation.BeamSearchEncoderDecoderOutput`], - - [`~generation.BeamSampleEncoderDecoderOutput`] + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] """ # 1. 
Handle `generation_config` and kwargs that might update it, and validate the resulting objects if generation_config is None: @@ -2440,37 +2405,23 @@ def generate( logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: - if model_kwargs.get("attention_mask", None) is None: - logger.warning( - "The attention mask and the pad token id were not set. As a consequence, you may observe " - "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." - ) - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") - generation_config.pad_token_id = eos_token_id + requires_attention_mask = "encoder_outputs" not in model_kwargs + kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None # 3. Define model inputs - # inputs_tensor has to be defined - # model_input_name is defined if model-specific keyword input is passed - # otherwise model_input_name is None - # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( inputs, generation_config.bos_token_id, model_kwargs ) batch_size = inputs_tensor.shape[0] - kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=inputs_tensor.device) # 4. Define other model kwargs model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["guidance_scale"] = generation_config.guidance_scale - if model_kwargs.get("attention_mask", None) is None: + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( - inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id + inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor ) if "encoder_hidden_states" not in model_kwargs: @@ -2484,46 +2435,28 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, + bos_token_id=generation_config._bos_token_tensor, device=inputs_tensor.device, ) # 6. Prepare `max_length` depending on other stopping criteria. - input_ids_seq_length = input_ids.shape[-1] - + input_ids_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - logger.warning( - f"Using the model-agnostic default `max_length` (={generation_config.max_length}) " - "to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation." - ) - elif generation_config.max_new_tokens is not None: - if not has_default_max_length: - logger.warning( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. 
`max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" - ) - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - - if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: - raise ValueError( - f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" - f" the maximum length ({generation_config.max_length})" - ) - if input_ids_seq_length >= generation_config.max_length: - logger.warning( - f"Input length of decoder_input_ids is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) - # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Musicgen Melody) + # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen) input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask( input_ids, - pad_token_id=generation_config.decoder_start_token_id, + pad_token_id=generation_config._decoder_start_token_tensor, max_length=generation_config.max_length, ) # stash the delay mask so that we don't have to recompute in each forward pass @@ -2544,7 +2477,7 @@ def generate( # 9. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, + input_ids_seq_length=input_ids_length, encoder_input_ids=inputs_tensor, prefix_allowed_tokens_fn=None, logits_processor=logits_processor, @@ -2599,7 +2532,7 @@ def generate( output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"]) # revert the pattern delay mask by filtering the pad token id - output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape( + output_ids = output_ids[output_ids != generation_config._pad_token_tensor].reshape( batch_size, self.decoder.num_codebooks, -1 ) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cab6fe8d094cd6..b21183618897c6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3196,6 +3196,40 @@ def test_assisted_decoding_in_gpu_cpu(self): ) self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) + def test_special_tokens_fall_back_to_model_default(self): + # PT-only test: TF doesn't support assisted decoding yet. 
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + torch_device + ) + test_bos_id = 50 + + # Sanity-check: the model has a BOS token set, and the first generated token is a BOS token + gen_output = model.generate() + self.assertTrue(model.generation_config.bos_token_id is not None) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + + # If we pass a generation config **with** a BOS token, `generate` will use it + generation_config = GenerationConfig(bos_token_id=test_bos_id) + gen_output = model.generate(generation_config=generation_config) + self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(test_bos_id == gen_output[0, 0]) + + # If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from + # `model.generation_config` + generation_config = GenerationConfig(bos_token_id=None) + gen_output = model.generate(generation_config=generation_config) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertFalse(test_bos_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id is None) + + # Changing `model.generation_config` will affect fallback behavior + model.generation_config.bos_token_id = test_bos_id + gen_output = model.generate(generation_config=generation_config) + self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0]) + self.assertTrue(test_bos_id == gen_output[0, 0]) + self.assertTrue(generation_config.bos_token_id is None) + @require_torch class TokenHealingTestCase(unittest.TestCase):
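For reference, the fallback behaviour pinned down by the new test can be reproduced from user code with the short sketch below. This is not part of the diff: it assumes a transformers build that already includes this change, reuses the tiny Mistral checkpoint and the illustrative token id 50 from the test, and the `max_new_tokens` / device handling are arbitrary choices for the sake of a runnable example.

import torch
from transformers import AutoModelForCausalLM, GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device)

# No special tokens set on this config: `generate` should fall back to `model.generation_config`.
sparse_config = GenerationConfig(max_new_tokens=5)
out = model.generate(generation_config=sparse_config)
print(bool(out[0, 0] == model.generation_config.bos_token_id))  # expected: True (fallback to the model default)

# Explicitly set `bos_token_id`: the user-provided value takes precedence, no fallback happens.
custom_bos_id = 50  # illustrative id, mirroring `test_bos_id` in the new test
override_config = GenerationConfig(bos_token_id=custom_bos_id, max_new_tokens=5)
out = model.generate(generation_config=override_config)
print(bool(out[0, 0] == custom_bos_id))  # expected: True (explicit value wins over the model default)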