
Commit

[test_all] final test fixes
gante committed May 3, 2024
1 parent 4e493f0 commit f035f0b
Showing 3 changed files with 35 additions and 33 deletions.
40 changes: 24 additions & 16 deletions src/transformers/generation/beam_search.py
@@ -218,8 +218,8 @@ def process(
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
@@ -245,8 +245,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     eos_token_id = torch.tensor(eos_token_id)

for batch_idx in range(batch_size):
batch_group_idx = batch_idx * self.num_beam_groups + group_index
@@ -322,15 +324,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps) // self.num_beam_groups

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
@@ -513,8 +517,8 @@ def process(
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
scores_for_all_vocab: torch.FloatTensor,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.Tensor]:
@@ -578,8 +582,10 @@ def process(
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     eos_token_id = torch.tensor(eos_token_id)

for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
@@ -811,15 +817,17 @@ def finalize(
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
batch_size = len(self._beam_hyps)

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+     if isinstance(eos_token_id, int):
+         eos_token_id = [eos_token_id]
+     eos_token_id = torch.tensor(eos_token_id)

# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
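The change repeated across process() and finalize() above normalizes `eos_token_id` to a tensor up front, so the rest of the beam-search code only ever sees a tensor. A minimal standalone sketch of that normalization (the `normalize_eos` helper name is illustrative, not part of the commit):

import torch
from typing import List, Optional, Union

def normalize_eos(eos_token_id: Optional[Union[int, List[int], torch.Tensor]]) -> Optional[torch.Tensor]:
    # Accept an int, a list of ints, or an already-built tensor; return a tensor (or None).
    if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id = torch.tensor(eos_token_id)
    return eos_token_id

print(normalize_eos(2))                  # tensor([2])
print(normalize_eos([2, 32000]))         # tensor([    2, 32000])
print(normalize_eos(torch.tensor([2])))  # already a tensor: passed through unchanged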
2 changes: 2 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -2226,6 +2226,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p:
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
+ self.eos_token_id = eos_token_id

if torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any():
logger.warning_once(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
@@ -2237,6 +2238,7 @@ def __init__(self, eos_token_id: Union[int, List[int], torch.Tensor], min_eos_p:
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores
+ self.eos_token_id = self.eos_token_id.to(scores.device)
if self.min_eos_p:
probs = torch.nn.functional.softmax(scores.float(), dim=-1)
# create scores full of -inf except for the eos_token_id
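The two added lines keep `eos_token_id` as a tensor attribute and move it to the device of `scores` inside __call__, so the processor works whether generation runs on CPU or GPU. A hedged sketch of the pattern in isolation (the class name and the simplified threshold logic are illustrative, not the transformers processor itself; only the device-handling lines mirror the diff):

import torch

class EosDeviceSketch:
    def __init__(self, eos_token_id, min_eos_p: float = 0.1):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if not isinstance(eos_token_id, torch.Tensor):
            eos_token_id = torch.tensor(eos_token_id)
        self.eos_token_id = eos_token_id
        self.min_eos_p = min_eos_p

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # `scores` may live on a GPU while the stored ids were built on CPU, so align devices first.
        self.eos_token_id = self.eos_token_id.to(scores.device)
        probs = torch.nn.functional.softmax(scores.float(), dim=-1)
        # If EOS is already likely enough, keep only the EOS logits for this row.
        eos_likely = probs[:, self.eos_token_id].sum(dim=-1, keepdim=True) > self.min_eos_p
        eos_only = torch.full_like(scores, float("-inf"))
        eos_only[:, self.eos_token_id] = scores[:, self.eos_token_id]
        return torch.where(eos_likely, eos_only, scores)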
26 changes: 9 additions & 17 deletions src/transformers/generation/utils.py
@@ -1333,13 +1333,17 @@ def _prepare_special_tokens(
"""
Prepares the special tokens for generation, overwriting the generation config with their processed versions
converted to tensor.
Note that `generation_config` is changed in place and stops being serializable after this method is called.
- If called outside `generate`, consider creating a copy of `generation_config` first.
+ That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
+ function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
"""

# Convert special tokens to tensors (if they exist)
def _tensor_or_none(token):
- return torch.tensor(token, device=self.device, dtype=torch.long) if token is not None else None
+ if token is None or isinstance(token, torch.Tensor):
+     return token
+ return torch.tensor(token, device=self.device, dtype=torch.long)

bos_token_id = _tensor_or_none(generation_config.bos_token_id)
eos_token_id = _tensor_or_none(generation_config.eos_token_id)
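With the new early return, `_tensor_or_none` is idempotent: a special token that is already a tensor (for example, when a prepared `generation_config` is reused) is passed through untouched instead of being re-wrapped. A standalone check of that behavior, with `self.device` replaced by a hardcoded `device` for the sketch:

import torch

device = torch.device("cpu")

def _tensor_or_none(token):
    if token is None or isinstance(token, torch.Tensor):
        return token
    return torch.tensor(token, device=device, dtype=torch.long)

eos = _tensor_or_none([2, 32000])
assert _tensor_or_none(eos) is eos      # already a tensor: returned as-is
assert _tensor_or_none(None) is None    # missing token stays missing
assert _tensor_or_none(0).dtype == torch.long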
@@ -2644,9 +2648,6 @@ def _beam_search(
return_dict_in_generate = generation_config.return_dict_in_generate
sequential = generation_config.low_memory

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams

@@ -2770,7 +2771,7 @@ def _beam_search(
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
)
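The comment above explains the sizing: with `n_eos_tokens` possible EOS ids, requesting `max(2, 1 + n_eos_tokens) * num_beams` candidates guarantees each beam can still be extended by at least one non-EOS token even if every EOS id ranks at the top. A tiny numeric illustration under assumed shapes:

import torch

batch_size, num_beams, vocab_size = 1, 2, 10
eos_token_id = torch.tensor([8, 9])          # two EOS ids, as a tensor
next_token_scores = torch.randn(batch_size, num_beams * vocab_size)

n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
topk_scores, topk_tokens = torch.topk(
    next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
)
print(topk_tokens.shape)  # torch.Size([1, 6]) -> 6 candidates for 2 beams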
@@ -2914,9 +2915,6 @@ def _beam_sample(
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams

@@ -3137,9 +3135,6 @@ def _group_beam_search(
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]

num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
@@ -3242,7 +3237,7 @@ def _group_beam_search(
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
)
@@ -3422,9 +3417,6 @@ def _constrained_beam_search(
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate

- if isinstance(eos_token_id, int):
-     eos_token_id = [eos_token_id]

batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams

@@ -3514,7 +3506,7 @@ def _constrained_beam_search(
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
- n_eos_tokens = len(eos_token_id) if eos_token_id else 0
+ n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
next_token_scores, next_tokens = torch.topk(
next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
)
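After these fixes, `eos_token_id` can be given to generate() as an int, a list of ints, or a torch.Tensor, and beam search handles all three the same way. A hedged usage sketch (the model choice is illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=4,
    max_new_tokens=20,
    eos_token_id=[tokenizer.eos_token_id],  # list form; an int or a tensor also works
)
print(tokenizer.decode(out[0], skip_special_tokens=True))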
