Skip to content

Commit

Permalink
Add support for empty tokens in regex guidance
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Sep 16, 2023
1 parent ccd8fb5 commit 9f416a5
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 27 deletions.
22 changes: 14 additions & 8 deletions outlines/text/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,22 +648,28 @@ def reduced_vocabulary(tokenizer: "Tokenizer"):
vocabulary = numba.typed.Dict.empty(
numba.types.string, numba.types.ListType(numba.int64)
)
empty_token_ids = set()
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue

vocabulary.setdefault(
tokenizer.convert_token_to_string(token),
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
token_str = tokenizer.convert_token_to_string(token)

return vocabulary
if token_str:
vocabulary.setdefault(
token_str,
numba.typed.List.empty_list(numba.int64),
).append(numba.int64(token_idx))
else:
empty_token_ids.add(token_idx)

return vocabulary, empty_token_ids


def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer: "Tokenizer",
) -> Dict[int, Dict[int, int]]:
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
"""Construct an FSM index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
Expand All @@ -673,7 +679,7 @@ def create_fsm_index_tokenizer(
`fsm` needs to be deterministically ordered so that the caching makes sense.
"""
vocabulary = reduced_vocabulary(tokenizer)
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary)

Expand All @@ -688,4 +694,4 @@ def create_fsm_index_tokenizer(
# Convert to token-to-end-state maps
states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()}

return states_to_token_subsets
return states_to_token_subsets, empty_token_ids
94 changes: 78 additions & 16 deletions outlines/text/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,32 @@ class Regex(Continuation):
"""

def __init__(self, model, regex_string: str, max_tokens: Optional[int]):
def __init__(
self,
model,
regex_string: str,
max_tokens: Optional[int],
allow_empty_tokens: bool = True,
):
"""
Parameters
----------
regex_string
The regex with which the token sampling process is guided/constrained.
max_tokens
The maximum number of tokens to be sampled.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
super().__init__(model, max_tokens)

self.allow_empty_tokens = allow_empty_tokens
regex_pattern = interegular.parse_pattern(regex_string)
self.regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())

self.states_to_token_maps = create_fsm_index_tokenizer(
self.states_to_token_maps, self.empty_token_ids = create_fsm_index_tokenizer(
self.regex_fsm, model.tokenizer
)

Expand Down Expand Up @@ -89,7 +108,14 @@ def create_proposal(
# Get the last token that was sampled
last_token = int(token_seq[-1])

if last_token != self.model.tokenizer.eos_token_id:
if last_token in self.empty_token_ids:
# An empty token was sampled, so the FSM state hasn't changed
next_state = last_state
mask = self._get_mask_for_state(
next_state, self.states_to_token_maps[last_state]
)

elif last_token != self.model.tokenizer.eos_token_id:
# If we previously ended with an EOS, we shouldn't be
# getting/sampling any more non-EOS tokens.
assert last_state > -1
Expand Down Expand Up @@ -141,7 +167,15 @@ def _get_mask_for_state(
-math.inf,
device=self.device,
)
mask[list(tokens_to_end_states.keys())] = 0

if self.allow_empty_tokens:
token_ids = list(self.empty_token_ids) + list(
tokens_to_end_states.keys()
)
else:
token_ids = list(tokens_to_end_states.keys())

mask[token_ids] = 0
mask = mask.unsqueeze(0)
self.mask_cache[state] = mask

Expand All @@ -152,7 +186,12 @@ def postprocess_completions(self, completions: List[str]) -> List[str]:
return super().postprocess_completions(completions)


def regex(model, regex_string: str, max_tokens: Optional[int] = None):
def regex(
model,
regex_string: str,
max_tokens: Optional[int] = None,
allow_empty_tokens: bool = True,
):
"""Generate text sequences that match the input regex.
Parameters
Expand All @@ -163,12 +202,14 @@ def regex(model, regex_string: str, max_tokens: Optional[int] = None):
The regular expression that generated expressions must match.
max_tokens
The maximum number of tokens to generate.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(model, regex_string, max_tokens)
return Regex(model, regex_string, max_tokens, allow_empty_tokens)


def integer(model, max_tokens: Optional[int] = None):
def integer(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True):
"""Generate integers.
The regex used to constrain the generation optionally matches plus or minus
Expand All @@ -181,12 +222,14 @@ def integer(model, max_tokens: Optional[int] = None):
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(model, r"[-+]?\d+", max_tokens)
return Regex(model, r"[-+]?\d+", max_tokens, allow_empty_tokens)


def float(model, max_tokens: Optional[int] = None):
def float(model, max_tokens: Optional[int] = None, allow_empty_tokens: bool = True):
"""Generate floating-point numbers.
The regex used to constrain the generation optionally matches plus or minus
Expand All @@ -199,18 +242,35 @@ def float(model, max_tokens: Optional[int] = None):
The language model to use to compute the next-token logits.
max_tokens
The maximum number of tokens to generate.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens)


def choice(model, choices: List[str], max_tokens: Optional[int] = None):
return Regex(
model,
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))",
max_tokens,
allow_empty_tokens,
)


def choice(
model,
choices: List[str],
max_tokens: Optional[int] = None,
allow_empty_tokens: bool = True,
):
"""Choose between different sequences."""
regex_str = r"(" + r"|".join(choices) + r")"
return Regex(model, regex_str, max_tokens)
return Regex(model, regex_str, max_tokens, allow_empty_tokens)


def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None):
def json(
model,
schema: Union[str, BaseModel],
max_tokens: Optional[int] = None,
allow_empty_tokens: bool = True,
):
"""Generate a text sequence that follows a JSON schema or Pydantic model.
Parameters
Expand All @@ -221,11 +281,13 @@ def json(model, schema: Union[str, BaseModel], max_tokens: Optional[int] = None)
The JSON schema or Pydantic model that guides the generation.
max_tokens
The maximum number of tokens to generate.
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
"""
if isinstance(schema, type(BaseModel)):
schema = dumps(schema.model_json_schema())

regex_str = build_regex_from_schema(schema)

return Regex(model, regex_str, max_tokens)
return Regex(model, regex_str, max_tokens, allow_empty_tokens)
39 changes: 39 additions & 0 deletions tests/text/generate/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,21 @@ def convert_token_to_string(self, token):
return token


class TokenizerWithEmpty(Tokenizer):
vocabulary = {"<EOS>": 0, "-": 1, "1": 2, "0.": 3, "431": 4, "a": 5, "A": 6, "": 7}
tokens = list(vocabulary.keys())


class Model:
tokenizer = Tokenizer()
device = "cpu"


class ModelWithEmpty:
tokenizer = TokenizerWithEmpty()
device = "cpu"


@pytest.mark.parametrize(
"regex_string, valid_first_token, proposal",
[
Expand Down Expand Up @@ -159,3 +169,32 @@ def test_float_proposal(input_ids, proposal):
result,
torch.tensor(proposal),
)


@pytest.mark.parametrize(
"input_ids, proposal, with_empty",
[
([[]], [[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, 1]], True),
(
[[]],
[[-math.inf, 1.0, 1.0, 1.0, 1.0, -math.inf, -math.inf, -math.inf]],
False,
),
([[3]], [[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, 1]], True),
(
[[3]],
[[1.0, -math.inf, 1.0, -math.inf, 1.0, -math.inf, -math.inf, -math.inf]],
False,
),
],
)
def test_empty_strings(input_ids, proposal, with_empty):
model = ModelWithEmpty()
generator = generate.float(model, allow_empty_tokens=with_empty)

logits = torch.ones(len(model.tokenizer.vocabulary))
result = generator.create_proposal(torch.tensor(input_ids), logits)
assert torch.equal(
result,
torch.tensor(proposal),
)
9 changes: 6 additions & 3 deletions tests/text/test_fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,12 @@ def test_create_fsm_index_tokenizer():

tokenizer = TransformersTokenizer("gpt2")

res = create_fsm_index_tokenizer(regex_fsm, tokenizer)
states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(
regex_fsm, tokenizer
)

assert len(res) / num_fsm_states > 0.94
assert not empty_token_ids
assert len(states_to_token_subsets) / num_fsm_states > 0.94


@pytest.mark.skip(reason="Only for local profiling")
Expand All @@ -336,7 +339,7 @@ def test_regex_index_performance():
tokenizer = TransformersTokenizer("gpt2")

# Pre-compile Numba functions
res = create_fsm_index_tokenizer(regex_fsm, tokenizer)
res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer)
assert len(res) > 1

profiler = LineProfiler(create_fsm_index_end_to_end)
Expand Down

0 comments on commit 9f416a5

Please sign in to comment.