Commit 334efb7

Merge pull request #244 from meher-m/transformers_fix
Adding support for transformers>=4.40.2 to avoid a crash with mbpp
loubnabnl authored Jun 24, 2024
2 parents 4659ecd + cc9033c commit 334efb7
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions bigcode_eval/utils.py
@@ -297,11 +297,22 @@ def complete_code(
                     **gen_kwargs,
                 )
             else:
-                generated_tokens = model.generate(
-                    input_ids=inputs,
-                    num_return_sequences=batch_size,
-                    **gen_kwargs,
-                )
+                # In transformers (>= 4.40.2), if the length of input_ids == max_length, a ValueError is thrown.
+                # We want to ignore this error in order to reproduce old results with mbpp.
+                try:
+                    generated_tokens = model.generate(
+                        input_ids=inputs,
+                        num_return_sequences=batch_size,
+                        **gen_kwargs,
+                    )
+                except ValueError as e:
+                    # When the length of input_ids == max_length, the generation is the same as the input
+                    if str(e).startswith(f"Input length of input_ids is {inputs.shape[1]}, but `max_length` is set to {gen_kwargs['max_length']}"):
+                        warnings.warn(f"An error with the following message was thrown: {e}. Returning the input as the generation, for higher scores consider using a larger `max_length`")
+                        generated_tokens = inputs
+                    else:
+                        raise e
             # each task is generated batch_size times
             generated_tasks = batch["task_id"].repeat(batch_size)
             generated_tokens = accelerator.pad_across_processes(
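
For context on the failure mode this patch guards against, here is a minimal standalone sketch, not part of the commit itself. It assumes transformers >= 4.40.2 and uses "gpt2" as an illustrative small checkpoint (the harness does not mandate any particular model). It triggers the same ValueError by setting `max_length` equal to the prompt length, then falls back to the prompt, mirroring the patched harness:

# Standalone repro sketch; assumes transformers >= 4.40.2 is installed.
# "gpt2" is an illustrative checkpoint, not one mandated by the harness.
import warnings

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("def add(a, b):", return_tensors="pt").input_ids
max_length = inputs.shape[1]  # the prompt already fills the length budget

try:
    # transformers >= 4.40.2 validates generation lengths up front and raises
    # a ValueError when the prompt length already equals `max_length`.
    generated_tokens = model.generate(input_ids=inputs, max_length=max_length)
except ValueError as e:
    # Same fallback as the patch: treat the prompt itself as the "generation",
    # which reproduces the pre-4.40.2 behavior for mbpp.
    warnings.warn(f"generate() rejected a full-length prompt: {e}")
    generated_tokens = inputs

Note that the patched harness only swallows this specific length-validation error; any other ValueError is re-raised, so unrelated generation failures still surface.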
