Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Regex generation method #175

Merged
merged 5 commits into from
Jul 13, 2023
Merged

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Jul 7, 2023

This trivially generalizes #166 by defining a Regex class that can be initialized with any regex string. The integer function now instantiates Regex with the appropriate regex string.

  • General Regex class
  • Function integer to initialize Regex with a regex that only matches integers
  • Function float to initialize Regex with a regex that only matches integers
  • Handle EOS tokens for open-ended sequences

@rlouf rlouf added text Linked to text generation enhancement labels Jul 7, 2023
@brandonwillard brandonwillard linked an issue Jul 7, 2023 that may be closed by this pull request
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rlouf, I added the EOS handling and padding.

tests/text/generate/test_integer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: The current error in CI looks like a "determinism" problem involving the numbering/order of the FSM states.

@rlouf
Copy link
Member Author

rlouf commented Jul 10, 2023

I would like to add a test to define the behavior when there is no possible match in the vocabulary. Then it should be ready to merge.

@rlouf rlouf linked an issue Jul 11, 2023 that may be closed by this pull request
@rlouf
Copy link
Member Author

rlouf commented Jul 12, 2023

Now the Regex class raises an exception when the vocabulary does not allow to build sequences that match the input regex. Ready for review.

brandonwillard
brandonwillard previously approved these changes Jul 12, 2023
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rlouf, I just pushed a change that makes the FSM start from the previous state it was in: i.e. it avoids decoding and rerunning the FSM for the entire sequence on each iteration. Tell me if that looks fine; otherwise, everything else looks good to me.

@brandonwillard
Copy link
Contributor

Agh, looks like the generated token sequences change shape when using transformers!

@brandonwillard brandonwillard force-pushed the regex-generation branch 5 times, most recently from 7d0a35c to ec10981 Compare July 13, 2023 00:17
Comment on lines +47 to +55
# TODO: This check might be a little too strict, because I think that
# while some states are made unreachable by a vocabulary (and will not
# be present in the following set difference), there could still be
# paths to terminal states emanating from the states that are reachable.
states_with_transition = {x[1] for x in pstate_to_vocab.keys()}
if len(self.regex_fsm.states.difference(states_with_transition)) > 0:
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should look into this. Perhaps as a follow-up issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I opened #184 to track this.

Comment on lines 38 to 47
mask = create_mask_from_regex(vocabulary, "^[0-9]+$")
mask = create_mask_from_regex(vocabulary, r"(0|[+-]?[1-9][0-9]+?)")

return mask


def create_float_mask(vocabulary: Dict[str, int]) -> torch.BoolTensor:
"""Create a mask to generate floating point numbers."""
mask = create_mask_from_regex(vocabulary, r"^(([0-9]+)?([.]([0-9]*)?)?|[.][0-9]+)$")
mask = create_mask_from_regex(
vocabulary, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests were getting a little flaky in CI, so I had to add these updates.

This allows one to add EOS transitions to the partial-parse-state-to-vocabulary
maps produced by map_partial_states_to_vocab.
brandonwillard and others added 2 commits July 12, 2023 19:44
This refactoring also removed the need for the antecedent mapping option in
`map_partial_states_to_vocab`.
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think it's good to go now.

# Get the tokens we haven't already processed
readable_tokens = token_seq[last_token_idx:]
# excluding any EOS tokens
not_eos_mask = [
Copy link
Member Author

@rlouf rlouf Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should never get a sequence with an EOS token here. Those are filtered out in Sequence.__call__. Is it still worth keeping this check?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you start it out with a sequence like [[10, 2, 0, 0]], you would only want to process the first two. That's what it should be doing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would that happen if you cannot get sequences with 0 by design?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe I had to add it for the tests at one point, but that might not longer be true.

@rlouf
Copy link
Member Author

rlouf commented Jul 13, 2023

I only have one minor comment regarding EOS tokens. Sequence.__call__ filters finished sequences, and Regex inherits from Continuation which marks a sequence as finished when an EOS token is found. We can open a follow-up issue for this.

@rlouf rlouf merged commit 34bc2fb into outlines-dev:main Jul 13, 2023
4 checks passed
@rlouf rlouf deleted the regex-generation branch July 13, 2023 10:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement text Linked to text generation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Manage EOS in regex-based generation Add a Float generation method
2 participants