Commit
Add temperature rescaling to the multinomial sampler
rlouf committed Feb 12, 2024
1 parent 7ce7d28 commit 29bd1fe
Showing 2 changed files with 74 additions and 4 deletions.
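In short, the multinomial sampler can now divide the logits by a temperature before sampling, in addition to top-k or top-p filtering. A minimal usage sketch, assuming only the constructor signature shown in the diff below and the `outlines.samplers` module path; the concrete values are illustrative:

from outlines.samplers import MultinomialSampler

# The first positional argument sets `samples`; the filtering options are
# keyword-only. Top-k filtering is appended to the processor list first and
# temperature rescaling second, so they are applied in that order at each step.
sampler = MultinomialSampler(3, top_k=50, temperature=0.7)
assert len(sampler.logits_processors) == 2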
40 changes: 36 additions & 4 deletions outlines/samplers.py
@@ -96,14 +96,18 @@ def __init__(
*,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
):
self.samples = samples

self.logits_processor = lambda x: x
self.logits_processors = []
if top_k is not None:
self.logits_processor = keep_top_k_logits(top_k)
self.logits_processors.append(keep_top_k_logits(top_k))
elif top_p is not None:
self.logits_processor = keep_top_p_logits(top_p)
self.logits_processors.append(keep_top_p_logits(top_p))

if temperature is not None:
self.logits_processors.append(rescale_logits(temperature))

def __call__(
self,
@@ -132,7 +136,10 @@ def __call__(
cumulative weights of each sequence of shape ``(n_seqs,)``.
"""
altered_next_token_logits = self.logits_processor(next_token_logits)
altered_next_token_logits = next_token_logits
for logit_processor in self.logits_processors:
    altered_next_token_logits = logit_processor(altered_next_token_logits)

probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)

@@ -196,6 +203,31 @@ def logits_processor(logits: torch.Tensor) -> torch.Tensor:
return logits_processor


def rescale_logits(temperature: float) -> Callable[[torch.Tensor], torch.Tensor]:
"""Build a function that rescales the token probabilities exponentially.
Parameters
----------
temperature
The value by which we rescale the logits.
"""

if not isinstance(temperature, float) or temperature < 0.0:
raise ValueError(
f"`temperature` must be a strictly negative floating point number, got {temperature} instead."
)
elif temperature == 0.0:
raise ValueError(
"Please use the greedy sampler instead of setting the temperature to 0."
)

def logits_processor(logits: torch.Tensor) -> torch.Tensor:
return logits / temperature

return logits_processor


class BeamSearchSampler:
"""Beam Search sampling algorithm.
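To make the effect of the new processor chain concrete, here is a hand-run of the same loop used in `MultinomialSampler.__call__`, applied outside the sampler. It is a sketch that assumes only the `keep_top_k_logits` and `rescale_logits` helpers from `outlines/samplers.py` and standard PyTorch; the logits values are made up:

import torch

from outlines.samplers import keep_top_k_logits, rescale_logits

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

# Apply each processor to the output of the previous one, mirroring __call__.
processors = [keep_top_k_logits(2), rescale_logits(0.5)]
altered = logits
for processor in processors:
    altered = processor(altered)

probs = torch.nn.functional.softmax(altered, dim=-1)

A temperature below 1 divides the logits by a small number and sharpens the resulting distribution; a temperature above 1 flattens it toward uniform. A temperature of 0 is rejected in `rescale_logits`, since that limit corresponds to greedy decoding.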
38 changes: 38 additions & 0 deletions tests/test_samplers.py
@@ -12,6 +12,7 @@
keep_top_k_logits,
keep_top_p_logits,
multinomial,
rescale_logits,
)


@@ -72,6 +73,32 @@ def test_multinomial():
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))


def test_multinomial_init():
sampler = MultinomialSampler()
assert sampler.logits_processors == []

sampler = MultinomialSampler(3)
assert sampler.logits_processors == []

sampler = MultinomialSampler(top_k=1)
assert len(sampler.logits_processors) == 1

sampler = MultinomialSampler(top_p=0.95)
assert len(sampler.logits_processors) == 1

sampler = MultinomialSampler(top_k=1, top_p=0.95)
assert len(sampler.logits_processors) == 1

sampler = MultinomialSampler(temperature=1.0)
assert len(sampler.logits_processors) == 1

sampler = MultinomialSampler(top_k=1, temperature=1.0)
assert len(sampler.logits_processors) == 2

sampler = MultinomialSampler(top_p=0.95, temperature=1.0)
assert len(sampler.logits_processors) == 2


def test_top_k():
with pytest.raises(ValueError, match="`k` must be a strictly"):
keep_top_k_logits(-1)
@@ -159,6 +186,17 @@ def test_top_p():
)


def test_rescale():
with pytest.raises(ValueError, match="`temperature` must"):
rescale_logits(1)

with pytest.raises(ValueError, match="`temperature` must"):
rescale_logits(-0.1)

with pytest.raises(ValueError, match="Please use the greedy sampler"):
rescale_logits(0.0)


def test_beam_search():
# Two beams, single sequence
sampler = BeamSearchSampler(2)
