Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow configurable sampling steps #318

Merged
merged 2 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions outlines/text/generate/continuation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import torch

from outlines.text.generate.sequence import Sequence

if TYPE_CHECKING:
from outlines.text.generate.sample import Sampler


class Continuation(Sequence):
"""Represents a completion generation model.
Expand All @@ -18,9 +21,13 @@ class Continuation(Sequence):
"""

def __init__(
self, model, max_tokens: Optional[int] = None, stop: Union[str, List[str]] = []
self,
model,
max_tokens: Optional[int] = None,
sampler: Optional["Sampler"] = None,
stop: Union[str, List[str]] = [],
):
super().__init__(model, max_tokens)
super().__init__(model, max_tokens, sampler)
self.eos_token_id = torch.tensor(
[self.model.tokenizer.eos_token_id], device=self.device
)
Expand Down Expand Up @@ -89,7 +96,11 @@ def postprocess_completions(self, completions: List[str]) -> List[str]:


def continuation(
model, max_tokens: Optional[int] = None, *, stop: Union[str, List[str]] = []
model,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
stop: Union[str, List[str]] = [],
):
"""Generate text sequences.

Expand All @@ -99,9 +110,14 @@ def continuation(
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
stop
A string or list of strings which, when generated, stops
the generation for this sequence.

"""
return Continuation(model, max_tokens, stop)
return Continuation(model, max_tokens, sampler, stop)
124 changes: 111 additions & 13 deletions outlines/text/generate/regex.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from json import dumps
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

import interegular
import torch
Expand All @@ -10,17 +10,28 @@
from outlines.text.generate.continuation import Continuation
from outlines.text.json_schema import build_regex_from_schema

if TYPE_CHECKING:
from outlines.text.generate.sample import Sampler


class Regex(Continuation):
"""Represents a regex-based generation model.

`Regex` instances are constrained generation models that only generate
sequences that match an input regex. We assume that the sequence can be
terminated (but not necessarily) when the finite state machine corresponding
to the regex is in an accepting state.
sequences matching a given regex.

>>> import outlines.text as text
>>> sequence = text.generate.regex(model, "(0|[1-9][0-9]+)")("Return an integer between 0 and 10")
>>> generator = text.generate.regex(model, "(0|[1-9][0-9]+)")

Sequences can then be generated from a prompt as follows:

>>> sequence_1 = generator("Return an integer between 0 and 10")
>>> sequence_2 = generator('Rate the movie "Hackers" on a scale from 0 to 10')

.. note::
Reuse instances of these guided generators (e.g. `generator` from the
above example) whenever possible, because constructing them has more
overhead than generating token sequences from them.

"""

Expand All @@ -29,6 +40,7 @@ def __init__(
model,
regex_string: str,
max_tokens: Optional[int] = None,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
initial_state: Optional[int] = None,
final_states: Optional[Set[int]] = None,
Expand All @@ -39,10 +51,17 @@ def __init__(

Parameters
----------
model
The instance of the model used to generate next-token probabilities.
regex_string
The regex with which the token sampling process is guided/constrained.
max_tokens
The maximum number of tokens to be sampled.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
states_to_token_maps
Expand All @@ -52,7 +71,7 @@ def __init__(
Pre-computed set of token ids for tokens that are empty strings.

"""
super().__init__(model, max_tokens)
super().__init__(model, max_tokens, sampler)

if (
states_to_token_maps is None
Expand Down Expand Up @@ -201,10 +220,17 @@ def regex(
model,
regex_string: str,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Generate text sequences that match the input regex.

.. note::
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.

Parameters
----------
model
Expand All @@ -213,46 +239,83 @@ def regex(
The regular expression that generated expressions must match.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.

"""
return Regex(model, regex_string, max_tokens, allow_empty_tokens)
return Regex(model, regex_string, max_tokens, sampler, allow_empty_tokens)


def integer(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True):
def integer(
model,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Generate integers.

The regex used to constrain the generation optionally matches plus or minus
signs and forbids leading zeros (even if the `int` function in Python allows
them).

.. note::
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.

Parameters
----------
model
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.

"""
return Regex(model, r"[-+]?\d+", max_tokens, allow_empty_tokens)
return Regex(model, r"[-+]?\d+", max_tokens, sampler, allow_empty_tokens)


def float(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True):
def float(
model,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Generate floating-point numbers.

The regex used to constrain the generation optionally matches plus or minus
signs, and forbids leading zeros (even if the `float` function in Python
allows them).

.. note::
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.

Parameters
----------
model
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.

Expand All @@ -261,6 +324,7 @@ def float(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = Tr
model,
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))",
max_tokens,
sampler,
allow_empty_tokens,
)

Expand All @@ -269,21 +333,50 @@ def choice(
model,
choices: List[str],
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Choose between different sequences."""
"""Choose between different sequences.

.. note::
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.

Parameters
----------
model
The language model to use to compute the next-token logits.
choices
The list of strings between which the generated sequence must choose.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
regex_str = r"(" + r"|".join(choices) + r")"
return Regex(model, regex_str, max_tokens, allow_empty_tokens)
return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)


def json(
model,
schema: Union[str, BaseModel],
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
allow_empty_tokens: bool = True,
):
"""Generate a text sequence that follows a JSON schema or Pydantic model.

.. note::
Reuse instances of these guided generators whenever possible,
because constructing them has more overhead than generating
token sequences from them. See the docstring for `Regex`.

Parameters
----------
model
Expand All @@ -292,6 +385,11 @@ def json(
The JSON schema or Pydantic model that guides the generation.
max_tokens
The maximum number of tokens to generate.
sampler
The function used to draw samples. Defaults to
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.

Expand All @@ -301,4 +399,4 @@ def json(

regex_str = build_regex_from_schema(schema)

return Regex(model, regex_str, max_tokens, allow_empty_tokens)
return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)
Loading
Loading