Skip to content

Commit

Permalink
Generate: nudge towards do_sample=False when temperature=0.0 (hug…
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored and blbadger committed Nov 8, 2023
1 parent 23ff568 commit 2159d31
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,13 @@ class TemperatureLogitsWarper(LogitsWarper):

def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, float) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)

self.temperature = temperature

Expand Down

0 comments on commit 2159d31

Please sign in to comment.