Generate: store special token tensors under a unique variable name #31980

Merged: 5 commits, Jul 22, 2024
Changes from 2 commits
134 changes: 73 additions & 61 deletions src/transformers/generation/utils.py
@@ -748,12 +748,12 @@ def _get_logits_warper(
warpers = LogitsProcessorList()

# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
# better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
# better score (i.e. keep len(list(generation_config.eos_token_tensor)) + 1)
if generation_config.num_beams > 1:
if isinstance(generation_config.eos_token_id, list):
min_tokens_to_keep = len(generation_config.eos_token_id) + 1
elif isinstance(generation_config.eos_token_id, torch.Tensor):
min_tokens_to_keep = generation_config.eos_token_id.shape[0] + 1
if isinstance(generation_config.eos_token_tensor, list):
min_tokens_to_keep = len(generation_config.eos_token_tensor) + 1
elif isinstance(generation_config.eos_token_tensor, torch.Tensor):
min_tokens_to_keep = generation_config.eos_token_tensor.shape[0] + 1
else:
min_tokens_to_keep = 2
else:
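For intuition, the `min_tokens_to_keep` arithmetic above can be checked in isolation (hypothetical values, not part of this diff): with two EOS ids, beam methods keep three candidates so that at least one non-EOS continuation always survives the warpers.

import torch

# Two EOS ids stored as a 1-D tensor, as this PR arranges.
eos_token_tensor = torch.tensor([2, 32000])

# Beam methods keep len(eos) + 1 candidates so that at least one
# non-EOS token survives top-k/top-p filtering.
min_tokens_to_keep = eos_token_tensor.shape[0] + 1
assert min_tokens_to_keep == 3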
@@ -857,31 +857,31 @@ def _get_logits_processor(
processors.append(
NoBadWordsLogitsProcessor(
generation_config.bad_words_ids,
generation_config.eos_token_id,
generation_config.eos_token_tensor,
)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.eos_token_tensor is not None
and generation_config.min_length > 0
):
processors.append(
MinLengthLogitsProcessor(
generation_config.min_length,
generation_config.eos_token_id,
generation_config.eos_token_tensor,
device=device,
)
)
if (
generation_config.min_new_tokens is not None
and generation_config.eos_token_id is not None
and generation_config.eos_token_tensor is not None
and generation_config.min_new_tokens > 0
):
processors.append(
MinNewTokensLengthLogitsProcessor(
input_ids_seq_length,
generation_config.min_new_tokens,
generation_config.eos_token_id,
generation_config.eos_token_tensor,
device=device,
)
)
@@ -912,7 +912,7 @@ def _get_logits_processor(
processors.append(
ExponentialDecayLengthPenalty(
generation_config.exponential_decay_length_penalty,
generation_config.eos_token_id,
generation_config.eos_token_tensor,
input_ids_seq_length,
)
)
@@ -991,8 +991,8 @@ def _get_stopping_criteria(
"stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`."
)
criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer))
if generation_config.eos_token_id is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id))
if generation_config.eos_token_tensor is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_tensor))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria

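For the stopping criterion, the check over a 1-D tensor of EOS ids boils down to a `torch.isin` on each sequence's last token. A self-contained sketch (an illustrative mimic, under the assumption that the real `EosTokenCriteria` compares last tokens against the EOS set):

import torch

def eos_reached(input_ids: torch.Tensor, eos_token_tensor: torch.Tensor) -> torch.Tensor:
    # True for every sequence whose last generated token is in the EOS set.
    return torch.isin(input_ids[:, -1], eos_token_tensor)

input_ids = torch.tensor([[5, 7, 2], [5, 7, 9]])
print(eos_reached(input_ids, torch.tensor([2, 32000])))  # tensor([ True, False])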
@@ -1343,13 +1343,15 @@ def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
Prepares the base generation config, then applies any generation configuration options from kwargs.
Prepares the base generation config, then applies any generation configuration options from kwargs. This
function handles retrocompatibility with respect to configuration files.
"""
# TODO joao: when we can detect `fullgraph=True` in `torch.compile` (https://github.com/pytorch/pytorch/pull/120400)
# replace `is_torchdynamo_compiling` by the corresponding check. As it is, we are being too restrictive with
# the parameterization in `fullgraph=False` so as to enable `fullgraph=True`.

# priority: `generation_config` argument > `model.generation_config` (the default generation config)
using_model_generation_config = False
if generation_config is None:
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
# three conditions must be met
@@ -1372,6 +1374,7 @@ def _prepare_generation_config(
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
)
self.generation_config = new_generation_config
using_model_generation_config = True
generation_config = self.generation_config

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
@@ -1389,6 +1392,16 @@ def _prepare_generation_config(
else:
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, fall back to the model's defaults for ALL unset special tokens
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
Comment on lines +1395 to +1404 (Member Author):
This is equivalent to the changes in this PR, which are better suited to this function -- handling retrocompatibility wrt config files


return generation_config, model_kwargs

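The fallback added in this hunk can be seen from the user side in a couple of lines (a usage sketch; the checkpoint name is the same tiny test model used in the new test below):

from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

# The user config leaves eos_token_id unset, so generate() falls back to
# the model's default; the user-provided object itself is not mutated.
user_config = GenerationConfig(max_new_tokens=5, eos_token_id=None)
model.generate(generation_config=user_config)
assert user_config.eos_token_id is None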
@@ -1486,52 +1499,46 @@ def _prepare_special_tokens(
function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""

# Convert special tokens to tensors (if they exist either in kwargs or in self.config)
def _tensor_or_none(token_kwargs, token_self, device=None):
if device is None:
device = self.device

token = token_kwargs if token_kwargs is not None else token_self
# Convert special tokens to tensors
def _tensor_or_none(token, device=None):
if token is None:
return token
elif isinstance(token, torch.Tensor):
return token.to(device)

device = device if device is not None else self.device
if isinstance(token, torch.Tensor):
return token.to(device)
return torch.tensor(token, device=device, dtype=torch.long)

bos_token_id = _tensor_or_none(
generation_config.bos_token_id, self.generation_config.bos_token_id, device=device
)
eos_token_id = _tensor_or_none(
generation_config.eos_token_id, self.generation_config.eos_token_id, device=device
)
pad_token_id = _tensor_or_none(
generation_config.pad_token_id, self.generation_config.pad_token_id, device=device
)
decoder_start_token_id = _tensor_or_none(
generation_config.decoder_start_token_id, self.generation_config.decoder_start_token_id, device=device
)
bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
decoder_start_token_tensor = _tensor_or_none(generation_config.decoder_start_token_id, device=device)

# for BC we also try to get `decoder_start_token_id` or `bos_token_id` (#30892)
if self.config.is_encoder_decoder:
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
decoder_start_token_tensor = (
decoder_start_token_tensor if decoder_start_token_tensor is not None else bos_token_tensor
)

# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
if eos_token_id is not None and eos_token_id.ndim == 0:
eos_token_id = eos_token_id.unsqueeze(0)
if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
eos_token_tensor = eos_token_tensor.unsqueeze(0)

# Set pad token if unset (and there are conditions to do so)
if pad_token_id is None and eos_token_id is not None:
if pad_token_tensor is None and eos_token_tensor is not None:
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
pad_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

# we can't infer attn mask if pad token is set to be eos token in model's generation config
if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
if (
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
@@ -1540,21 +1547,26 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
)

# Sanity checks/warnings
if self.config.is_encoder_decoder and decoder_start_token_id is None:
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not "
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation will not "
"stop until the maximum length is reached. Depending on other flags, it may even crash."
)

# Update generation config with the updated special tokens tensors
generation_config.bos_token_id = bos_token_id
generation_config.eos_token_id = eos_token_id
generation_config.pad_token_id = pad_token_id
generation_config.decoder_start_token_id = decoder_start_token_id
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
# (in their non-tensor form), in order to enable end-to-end compilation. See
# https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
generation_config.bos_token_tensor = bos_token_tensor
generation_config.eos_token_tensor = eos_token_tensor
generation_config.pad_token_tensor = pad_token_tensor
generation_config.decoder_start_token_tensor = decoder_start_token_tensor
Collaborator:
Would it work to have a property, `eos_token_id`, with `_eos_token_tensor` underlying? When you set it, you cast to tensor format. Might be simpler in general?
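A rough sketch of what that property-based alternative might look like (hypothetical, not what this PR implements):

import torch

class GenerationConfigSketch:
    @property
    def eos_token_id(self):
        return self._eos_token_tensor

    @eos_token_id.setter
    def eos_token_id(self, value):
        # Cast to tensor format on assignment, as suggested.
        self._eos_token_tensor = None if value is None else torch.as_tensor(value)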

Member Author:
uhmmm @property will further tangle us with state/python classes, which I'm not a fan of for compile purposes 🤔

I am going to rename the tensor variables from xxx_token_tensor to _xxx_token_tensor though, to help with readability!

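The compile concern is easier to see with the invariant spelled out: the config keeps the original ints under one name and the derived device tensor under another, so a compiled `generate` never observes an attribute switching between `int`/`list` and `torch.Tensor` across calls. A sketch of that invariant (illustrative only; mirrors what `_prepare_special_tokens` writes above):

import torch
from transformers import GenerationConfig

config = GenerationConfig(eos_token_id=[2, 32000])  # plain ints: serializable, stable

# The derived tensor form lives under a *different* attribute name.
config.eos_token_tensor = torch.tensor(config.eos_token_id, dtype=torch.long)

assert isinstance(config.eos_token_id, list)
assert isinstance(config.eos_token_tensor, torch.Tensor)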

@torch.no_grad()
def generate(
@@ -1689,10 +1701,10 @@ def generate(
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
generation_config.pad_token_tensor is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_tensor) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1709,7 +1721,7 @@ def generate(

if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
inputs_tensor, generation_config.pad_token_tensor, generation_config.eos_token_tensor
)

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
@@ -1724,7 +1736,7 @@ def generate(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
decoder_start_token_id=generation_config.decoder_start_token_tensor,
device=inputs_tensor.device,
)
else:
@@ -2268,7 +2280,7 @@ def _dola_decoding(
raise ValueError("DoLa decoding is only available for decoder-only models.")
# init values

pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config.pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -2475,7 +2487,7 @@ def _contrastive_search(
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
top_k = generation_config.top_k
penalty_alpha = generation_config.penalty_alpha
pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config.pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -2866,7 +2878,7 @@ def _sample(
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
pad_token_id = generation_config.pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3073,8 +3085,8 @@ def _beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config.pad_token_tensor
eos_token_id = generation_config.eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3355,8 +3367,8 @@ def _group_beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config.pad_token_tensor
eos_token_id = generation_config.eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
@@ -3647,8 +3659,8 @@ def _constrained_beam_search(
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
pad_token_id = generation_config.pad_token_tensor
eos_token_id = generation_config.eos_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
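Condensing the special-token handling from `_prepare_special_tokens` above, the tensor normalization reduces to a few lines (a condensed sketch, not the full function):

import torch

def normalize_eos(eos, device="cpu"):
    # Mirrors _tensor_or_none plus the 1-D normalization for EOS.
    if eos is None:
        return None
    if not isinstance(eos, torch.Tensor):
        eos = torch.tensor(eos, device=device, dtype=torch.long)
    else:
        eos = eos.to(device)
    return eos.unsqueeze(0) if eos.ndim == 0 else eos

print(normalize_eos(2))          # tensor([2]): a scalar id becomes a 1-D tensor
print(normalize_eos([2, 32000])) # tensor([2, 32000]): multiple EOS ids stay 1-D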
34 changes: 34 additions & 0 deletions tests/generation/test_utils.py
@@ -3196,6 +3196,40 @@ def test_assisted_decoding_in_gpu_cpu(self):
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)

def test_special_tokens_fall_back_to_model_default(self):
Member Author:
This test was missing in #31254 😛

# PT-only test.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
torch_device
)
test_bos_id = 50

# Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
gen_output = model.generate()
self.assertTrue(model.generation_config.bos_token_id is not None)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])

# If we pass a generation config **with** a BOS token, `generate` will use it
generation_config = GenerationConfig(bos_token_id=test_bos_id)
gen_output = model.generate(generation_config=generation_config)
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])

# If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
# `model.generation_config`
generation_config = GenerationConfig(bos_token_id=None)
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertFalse(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)

# Changing `model.generation_config` will affect fallback behavior
model.generation_config.bos_token_id = test_bos_id
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)


@require_torch
class TokenHealingTestCase(unittest.TestCase):