
Commit

Generate: store special token tensors under a unique variable name (#31980)

* rename stuff

* english; this one shouldn't be changed

* add a _ to the new var names

* musicgen

* derp
gante authored Jul 22, 2024
1 parent aa8f86a commit c38c55f
Showing 4 changed files with 187 additions and 273 deletions.
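In short: after this commit, `generate` leaves the user-facing special token ids (`bos_token_id`, `eos_token_id`, `pad_token_id`, `decoder_start_token_id`) untouched on the generation config and stores their device-placed tensor copies under separate, underscore-prefixed attributes. A minimal sketch of the resulting attribute layout, with the values set by hand for illustration (inside `generate` they are produced by `_prepare_special_tokens`, shown in the diff below):

import torch
from transformers import GenerationConfig

generation_config = GenerationConfig(eos_token_id=2, pad_token_id=0)

# The user-facing attributes keep their original, non-tensor types.
assert isinstance(generation_config.eos_token_id, int)

# The tensor copies live under dedicated private names, so the public attributes never
# change type between calls (see the NOTE added in `_prepare_special_tokens` below).
generation_config._eos_token_tensor = torch.tensor([2])  # always 1D; may hold several ids
generation_config._pad_token_tensor = torch.tensor(0)
print(generation_config.eos_token_id, generation_config._eos_token_tensor)  # 2 tensor([2])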
134 changes: 73 additions & 61 deletions src/transformers/generation/utils.py
@@ -754,12 +754,12 @@ def _get_logits_warper(
warpers = LogitsProcessorList()

# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
- # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+ # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
- if isinstance(generation_config.eos_token_id, list):
- min_tokens_to_keep = len(generation_config.eos_token_id) + 1
- elif isinstance(generation_config.eos_token_id, torch.Tensor):
- min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
+ if isinstance(generation_config._eos_token_tensor, list):
+ min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
+ elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
+ min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
@@ -863,31 +863,31 @@ def _get_logits_processor(
processors.append(
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
)
)
if (
generation_config.min_length is not None
- and generation_config.eos_token_id is not None
+ and generation_config._eos_token_tensor is not None
and generation_config.min_length > 0
):
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
- and generation_config.eos_token_id is not None
+ and generation_config._eos_token_tensor is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
device=device,
)
)
@@ -918,7 +918,7 @@ def _get_logits_processor(
processors.append(
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
- generation_config.eos_token_id,
+ generation_config._eos_token_tensor,
input_ids_seq_length,
)
)
@@ -997,8 +997,8 @@ def _get_stopping_criteria(
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
- if generation_config.eos_token_id is not None:
- criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
+ if generation_config._eos_token_tensor is not None:
+ criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria

@@ -1349,13 +1349,15 @@ def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
- Prepares the base generation config, then applies any generation configuration options from kwargs.
+ Prepares the base generation config, then applies any generation configuration options from kwargs. This
+ function handles retrocompatibility with respect to configuration files.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ using_model_generation_config = False
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
@@ -1378,6 +1380,7 @@ def _prepare_generation_config(
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
+ using_model_generation_config = True
generation_config = self.generation_config

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
@@ -1395,6 +1398,16 @@ def _prepare_generation_config(
else:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
+ # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
+ if not using_model_generation_config:
+ if generation_config.bos_token_id is None:
+ generation_config.bos_token_id = self.generation_config.bos_token_id
+ if generation_config.eos_token_id is None:
+ generation_config.eos_token_id = self.generation_config.eos_token_id
+ if generation_config.pad_token_id is None:
+ generation_config.pad_token_id = self.generation_config.pad_token_id
+ if generation_config.decoder_start_token_id is None:
+ generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id

return generation_config, model_kwargs
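For illustration, a standalone sketch of the fallback added above. The helper name `fill_missing_special_tokens` is ours; in the library this logic lives inline in `_prepare_generation_config`, with the defaults coming from `self.generation_config`:

from transformers import GenerationConfig

def fill_missing_special_tokens(generation_config, model_defaults):
    # Mirrors the four `is None` checks above: a user-provided config only inherits the
    # model's special tokens for the fields it left unset.
    for name in ("bos_token_id", "eos_token_id", "pad_token_id", "decoder_start_token_id"):
        if getattr(generation_config, name) is None:
            setattr(generation_config, name, getattr(model_defaults, name))
    return generation_config

user_config = GenerationConfig(max_new_tokens=20)  # no special tokens set by the user
model_defaults = GenerationConfig(bos_token_id=1, eos_token_id=2, pad_token_id=0)
fill_missing_special_tokens(user_config, model_defaults)
print(user_config.eos_token_id, user_config.pad_token_id)  # 2 0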

@@ -1493,52 +1506,46 @@ def _prepare_special_tokens(
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""

- # Convert special tokens to tensors (if they exist either in kwargs or in self.config)
- def _tensor_or_none(token_kwargs, token_self, device=None):
- if device is None:
- device = self.device
-
- token = token_kwargs if token_kwargs is not None else token_self
+ # Convert special tokens to tensors
+ def _tensor_or_none(token, device=None):
if token is None:
return token
- elif isinstance(token, torch.Tensor):
- return token.to(device)

+ device = device if device is not None else self.device
+ if isinstance(token, torch.Tensor):
+ return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)

- bos_token_id = _tensor_or_none(
- generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
- )
- eos_token_id = _tensor_or_none(
- generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
- )
- pad_token_id = _tensor_or_none(
- generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
- )
- decoder_start_token_id = _tensor_or_none(
- generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
- )
+ bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
+ eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
+ pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
+ decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)

# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
if self.config.is_encoder_decoder:
- decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+ decoder_start_token_tensor = (
+ decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
+ )

# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
- if eos_token_id is not None and eos_token_id.ndim == 0:
- eos_token_id = eos_token_id.unsqueeze(0)
+ if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
+ eos_token_tensor = eos_token_tensor.unsqueeze(0)

# Set pad token if unset (and there are conditions to do so)
- if pad_token_id is None and eos_token_id is not None:
+ if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
- pad_token_id = eos_token_id[0]
- logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
+ pad_token_tensor = eos_token_tensor[0]
+ logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

# we can't infer attn mask if pad token is set to be eos token in model's generation config
- if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+ if (
+ eos_token_tensor is not None
+ and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
+ ):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
@@ -1547,21 +1554,26 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
)

# Sanity checks/warnings
- if self.config.is_encoder_decoder and decoder_start_token_id is None:
+ if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
- if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
+ if eos_token_tensor is not None and (
+ torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
+ ):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)

# Update generation config with the updated special tokens tensors
- generation_config.bos_token_id = bos_token_id
- generation_config.eos_token_id = eos_token_id
- generation_config.pad_token_id = pad_token_id
- generation_config.decoder_start_token_id = decoder_start_token_id
+ # NOTE: this must be written into a different attribute name than the one holding the original special tokens
+ # (in their non-tensor form), in order to enable end-to-end compilation. See
+ # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
+ generation_config._bos_token_tensor = bos_token_tensor
+ generation_config._eos_token_tensor = eos_token_tensor
+ generation_config._pad_token_tensor = pad_token_tensor
+ generation_config._decoder_start_token_tensor = decoder_start_token_tensor

@torch.no_grad()
def generate(
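For illustration, a standalone sketch of the conversion performed by the new `_tensor_or_none` helper and the normalization steps above (device handling is reduced to a plain keyword default here; in the library it falls back to `self.device`):

import torch

def _tensor_or_none(token, device="cpu"):
    if token is None:
        return token
    if isinstance(token, torch.Tensor):
        return token.to(device)
    return torch.tensor(token, device=device, dtype=torch.long)

eos_token_tensor = _tensor_or_none([2, 32000])        # a model may declare several EOS ids
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
    eos_token_tensor = eos_token_tensor.unsqueeze(0)  # scalar EOS ids become 1D tensors
pad_token_tensor = _tensor_or_none(None)
if pad_token_tensor is None and eos_token_tensor is not None:
    pad_token_tensor = eos_token_tensor[0]            # unset pad falls back to the first EOS id
print(eos_token_tensor, pad_token_tensor)             # e.g. tensor([2, 32000]) tensor(2)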
@@ -1696,10 +1708,10 @@ def generate(
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
- generation_config.pad_token_id is not None
+ generation_config._pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
- and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+ and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1716,7 +1728,7 @@ def generate(

if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
- inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+ inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
)

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
@@ -1731,7 +1743,7 @@ def generate(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
- decoder_start_token_id=generation_config.decoder_start_token_id,
+ decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
@@ -2279,7 +2291,7 @@ def _dola_decoding(
raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values

- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -2486,7 +2498,7 @@ def _contrastive_search(
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
top_k = generation_config.top_k
penalty_alpha = generation_config.penalty_alpha
- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -2877,7 +2889,7 @@ def _sample(
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
+ pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3084,8 +3096,8 @@ def _beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3366,8 +3378,8 @@ def _group_beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3658,8 +3670,8 @@ def _constrained_beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
- pad_token_id = generation_config.pad_token_id
- eos_token_id = generation_config.eos_token_id
+ pad_token_id = generation_config._pad_token_tensor
+ eos_token_id = generation_config._eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
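Finally, a toy illustration of how the decoding loops that now read `_pad_token_tensor` and `_eos_token_tensor` typically use the two tensors (a sketch only, not the library's `_sample` implementation): finished rows keep emitting the pad id, and a row is marked finished once it produces any of the EOS ids.

import torch

pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor([2])               # 1D, possibly holding several ids

unfinished = torch.ones(3, dtype=torch.long)   # three sequences, all still running
next_tokens = torch.tensor([5, 2, 7])          # the second sequence emits an EOS id

# Finished sequences are forced to the pad id, then the EOS hit marks row 1 as finished.
next_tokens = next_tokens * unfinished + pad_token_id * (1 - unfinished)
unfinished = unfinished * (~torch.isin(next_tokens, eos_token_id)).long()
print(next_tokens, unfinished)                 # tensor([5, 2, 7]) tensor([1, 0, 1])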