
Add CFG-guided generation #391

Merged: 10 commits, Dec 19, 2023

Conversation

@benlipkin (contributor) commented Nov 26, 2023

Thanks to the outlines team for putting together a great library. I had been rolling my own limited solutions for constrained generation tasks until I came across this. outlines has now been subbed into most of my use cases.

One additional feature that I've been waiting on a nice solution for is CFG-guided generation. Since I like the outlines API, I took a peek at the library, and have put together an initial draft of a CFG class here, along with a working example here.

The high-level summary is as follows:

  1. The lark interactive parser proposes a regex for the next valid terminal.
  2. The Regex class's create_proposal method is called repeatedly until the regex is completed.
  3. The cfg parser advances to the next state, a new Regex class is built, and it is then called (with an updated start index so as to skip generations from the previous regex).
  4. This repeats until an EOS token is sampled, or until the only valid terminal left by the cfg parser is the EOS token.

This has been checked for batch_size>1, but still needs to be stress-tested quite a bit more in terms of the range of grammars, etc. Looking forward to comments and feedback.
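The steps above can be sketched with a toy stand-in for the lark interactive parser; ToyInteractiveParser, cfg_guided_generate, and the canned completions are illustrative only, not the actual outlines or lark APIs:

```python
import re

class ToyInteractiveParser:
    """Stand-in for lark's interactive parser: a fixed sequence of terminals."""
    def __init__(self, terminal_regexes):
        self.terminal_regexes = list(terminal_regexes)
        self.pos = 0

    def accepts(self):
        # Regex for the next valid terminal, or None once only EOS is valid.
        if self.pos < len(self.terminal_regexes):
            return self.terminal_regexes[self.pos]
        return None

    def advance(self):
        self.pos += 1

def cfg_guided_generate(parser, complete_regex):
    # complete_regex stands in for step 2: sampling tokens until the
    # current terminal's regex is completed.
    output = ""
    while (pattern := parser.accepts()) is not None:  # step 1: propose regex
        piece = complete_regex(pattern)               # step 2: fill the regex
        assert re.fullmatch(pattern, piece)
        output += piece
        parser.advance()                              # step 3: advance parser
    return output                                     # step 4: only EOS left

# Toy "grammar": NUMBER "+" NUMBER
parser = ToyInteractiveParser([r"\d+", r"\+", r"\d+"])
canned = {r"\d+": "42", r"\+": "+"}
print(cfg_guided_generate(parser, canned.__getitem__))  # -> 42+42
```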

@benlipkin (author)

Have extended testing and found a bug; please defer review until the next commit (in progress).

@benlipkin (author)

thanks, should be good now!

@rlouf (member) commented Nov 30, 2023

Could you run pre-commit locally to fix the formatting issues? Also, we may consider merging this PR before #366, since that one is a big interface change, unless you are willing to make the changes to this PR to match the new interface.

@rlouf added the enhancement and structured generation labels Nov 30, 2023
@benlipkin benlipkin changed the title initial draft of cfg guided generation cfg-guided generation Nov 30, 2023
@benlipkin (author)

Thanks, pre-commit hooks should be passing now.

I agree that merging before #366 makes sense.

Taking a peek at that interface refactor, it looks like it should ultimately simplify this implementation as well, which is great. I should note that I am in a bit of a deadline-heavy period now, so I would not expect to be able to contribute in the next ~2 weeks, but after that I would be happy to refactor this if no one else has gotten to it by then.

@brandonwillard brandonwillard force-pushed the cfg-guided-generation branch 2 times, most recently from d9ddc72 to 6013abf Compare December 4, 2023 20:40
Two review threads on outlines/text/generate/cfg.py (outdated, resolved)
@brandonwillard (contributor) commented Dec 4, 2023

FYI: I've rebased and squashed. Still reviewing, though...

Two review threads on outlines/text/generate/cfg.py (outdated, resolved)
@rlouf (member) commented Dec 11, 2023

@benlipkin I ended up merging #366 because it was blocking many other improvements. It should be relatively easy to adapt this PR to make it work with the new interface. You will need to implement a CFGFSM class in outlines/fsm/fsm.py. Overall I think this will simplify the code.
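A rough skeleton of what such a class might look like; the method names mirror the FSM protocol discussed in this thread, but the exact signatures are assumptions, not the real outlines interface:

```python
# Hypothetical skeleton of a CFGFSM class for outlines/fsm/fsm.py.
# Method names and signatures are assumptions for illustration only.

class CFGFSM:
    def __init__(self, cfg_string: str, tokenizer):
        self.cfg_string = cfg_string
        self.tokenizer = tokenizer
        # In a real implementation, a lark interactive parser would be built
        # here and used to propose a regex for the next valid terminal.
        self.num_tokens_generated = 0

    def allowed_token_ids(self, state: int) -> list:
        # Would delegate to the RegexFSM built for the current terminal.
        raise NotImplementedError

    def next_state(self, state: int, token_id: int) -> int:
        # Would advance the inner regex FSM and, on completion, advance the
        # cfg parser and rebuild the regex FSM for the next terminal.
        raise NotImplementedError

    def is_final_state(self, state: int) -> bool:
        # -1 is used as the EOS sentinel state.
        return state == -1

    def reset(self) -> None:
        # Clear per-batch state so a generator can be reused on new inputs.
        self.num_tokens_generated = 0
```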

@benlipkin (author)

Thanks @rlouf and apologies for the delays here. My other deadline is now in, and I will begin this refactor today.

@rlouf (member) commented Dec 11, 2023

No problem, thank you for taking this on!

@benlipkin (author)

While working on this, I noticed a bug (based on my understanding of how tokens are expected to be tracked by a user) in how FSM.num_tokens_generated is being updated, both when using a batch_size>1 as well as when reusing an existing generator on a new batch of inputs.

It is my understanding that max_tokens, as set when constructing an FSM, is intended to cap the number of tokens generated for each sequence. Instead, the counter was previously being incremented each time FSM.next_state was called, which leads to the following behavior. If max_tokens=10 and batch_size=5, the generations are 5 sequences of 2 tokens each, with num_tokens_generated actually finishing at 15: the counter increments once per sequence per step until reaching 10, then makes one more pass, still incrementing, but since num_tokens_generated >= max_tokens it just sets the FSM state to EOS. If this generator is then called again on another batch of 5 inputs, the returned generations will be 5 sequences of 1 token each: the first tokens are always generated before FSM.next_state is called, and the FSM then sees the stale num_tokens_generated > max_tokens and immediately sets the next state to EOS after the first generated element of each.

Instead I propose the following:

In outlines/generate/generator.py, each time an FSM method is called as part of the loop over a batch, the index of the element is passed. Only if idx is 0 (also the default value for the protocol) is num_tokens_generated incremented (seen in outlines/fsm/fsm.py). Then, each time a generator is called on a new batch of inputs, there is a call to FSM.reset(), which sets num_tokens_generated back to 0 (seen in outlines/generate/api.py).

This reset method will also be important for CFGFSM, which will need to track some state as part of the object, including partial completions, in order to determine the current parser state for each sequence uniquely. Regarding this class implementation, I have a general plan of approach. Some aspects are simpler with respect to the old API, but some will be more complicated. In particular, it is not yet fully clear to me what the best way will be to determine when to let a regex continue generating vs. shift to the next parser state when either is a valid option (previously we used the sampler to determine whether the regex would generate an EOS next, and if so shifted to the next regex, else removed the option to generate EOS from the mask and kept generating). I will share more info on this later when I've outlined the rest of the interface and edge cases.
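A toy model of the proposed counting fix; CountingFSM is illustrative and not the real FSM class, but it shows how the idx == 0 guard and reset() keep num_tokens_generated equal to the per-sequence token count:

```python
class CountingFSM:
    """Toy model of the proposed fix: count tokens per sequence, not per call."""
    def __init__(self, max_tokens):
        self.max_tokens = max_tokens
        self.num_tokens_generated = 0

    def next_state(self, state, token_id, idx=0):
        # Only increment for the first element of the batch, so the counter
        # tracks tokens per sequence rather than total FSM calls.
        if idx == 0:
            self.num_tokens_generated += 1
        if self.num_tokens_generated >= self.max_tokens:
            return -1  # EOS sentinel state
        return state

    def reset(self):
        # Called when the generator is reused on a new batch of inputs.
        self.num_tokens_generated = 0

# With max_tokens=10 and batch_size=5, each sequence now gets 10 steps,
# instead of the 2 steps described in the buggy behavior above.
fsm = CountingFSM(max_tokens=10)
state, steps = 0, 0
while state != -1:
    for idx in range(5):
        state = fsm.next_state(state, token_id=0, idx=idx)
    steps += 1
print(steps, fsm.num_tokens_generated)  # -> 10 10
```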

@rlouf (member) commented Dec 13, 2023

I'll take a look at that bug tomorrow. We were thinking about moving this logic to SequenceGenerator.

@benlipkin (author) commented Dec 13, 2023

Have outlined the rest of the PR. This implementation should be sound, in that it will always generate strings within the language, but it is not complete, in that there are strings in the language that cannot be generated.

This surrounds a particular branching point: determining when an underlying regex is done, thereby allowing us to transition to the next CFG parser state, which I point out here. If a particular regex's only next state is generating the EOS token, then we can easily transition to the next regex. But if the regex is in a final state yet could generate more, e.g., the string "a" in the language aa*, then we face a decision point: with some probability P we should assume the regex is done and transition to the next regex, and with probability 1-P we should remove EOS from the valid next set of states and continue generating. Previously, the CFG class had access to the underlying model logits and sampler, and so was able to drive this branching point, but the current interface does not support this. In theory, I could pass these in from outlines/generate/generator.py, but this seems clunky, so please advise on the preferred approach.
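The decision point might look like the following sketch; choose_branch and p_transition are hypothetical names, and in practice P would come from the sampler's probability of generating EOS rather than a fixed parameter:

```python
import random

def choose_branch(in_final_state, can_continue, p_transition=0.5, rng=random.random):
    """At a final-but-extendable regex state, either hand off to the next
    parser state (probability P) or mask out EOS and keep generating
    (probability 1-P). Names and the fixed p_transition are illustrative."""
    if not in_final_state:
        return "continue"    # regex not finished: must keep generating
    if not can_continue:
        return "transition"  # only EOS is valid next: must hand off
    return "transition" if rng() < p_transition else "continue"
```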

Relatedly, the RegexFSM class is also not complete. If is_generation_finished returns True in outlines/generate/generator.py, generation stops, but the implementation of is_final_state for RegexFSM is return state in self.final_states. This means that if generating 100 samples from the language aa*, the returned value will always be ["a"]*100, rather than allowing for the generation of the infinitely many other strings in the language. A patch here might be return state == FSMState(-1), i.e., only end if EOS was already generated. This "wastes" a token for some languages, but fixes this edge case.
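The two behaviors, illustrated on the language aa* with hypothetical helper names:

```python
# Contrast of the current behavior and the proposed patch for
# RegexFSM.is_final_state; both functions here are illustrative stand-ins,
# with -1 as the EOS sentinel state.

def is_final_state_current(state, final_states):
    # Stops as soon as any accepting state is reached: for aa*, "a" always
    # ends generation, so "aa", "aaa", ... can never be sampled.
    return state in final_states

def is_final_state_patched(state):
    # Only stop once EOS itself has been generated; accepting states stay
    # live, so longer strings in the language remain reachable.
    return state == -1
```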

@benlipkin (author)

Also, placing this in a separate comment, but @brandonwillard , I tried out plugging in the PartialLark implementation, but the set of strings returned by interactive.accepts() here is not updated to reflect completion of previous tokens. Haven't taken an in-depth look at that class yet, but flagging for you in the meantime.

@brandonwillard (contributor)

> Also, placing this in a separate comment, but @brandonwillard , I tried out plugging in the PartialLark implementation, but the set of strings returned by interactive.accepts() here is not updated to reflect completion of previous tokens. Haven't taken an in-depth look at that class yet, but flagging for you in the meantime.

No problem, we can do all that in a follow-up.

@brandonwillard (contributor) left a comment
FYI: We'll need to rebase to remove the merge commit and get things matched up with main. One of us can do that, @benlipkin, once you've reached a good stopping point, if you'd like.

@benlipkin (author) commented Dec 13, 2023

Just rebased. That edge case I raised still needs to be covered at some point, but otherwise should be set.

@rlouf rlouf changed the title cfg-guided generation Add CFG-guided generation Dec 14, 2023
@rlouf (member) commented Dec 14, 2023

Great addition! I rebased the code on main to remove the merge commit. However, when I ran the examples I noticed that type constraints were not respected: for instance, for the calc_grammar example I get strings in place of NUMBER. Is it possible to make the generator use the regexes defined in outlines/fsm/json_schema.py?

@benlipkin (author)

This is because the calc_grammar is slightly more expressive than the name suggests and supports variable assignments and variable evaluation in expressions.

If we remove the following lines from the grammar:

| NAME "=" sum    -> assign_var
| NAME             -> var
%import common.CNAME -> NAME

leaving:

calc_grammar = """
    ?start: sum

    ?sum: product
        | sum "+" product   -> add
        | sum "-" product   -> sub

    ?product: atom
        | product "*" atom  -> mul
        | product "/" atom  -> div

    ?atom: NUMBER           -> number
         | "-" atom         -> neg
         | "(" sum ")"

    %import common.NUMBER
    %import common.WS_INLINE

    %ignore WS_INLINE
"""

then, generations look as follows:

SUCCESS 7-9+7-2*9/5-4+3*.6+00+.1+6-.00-4
SUCCESS 8
SUCCESS 00/7/.8/3
SUCCESS 2+2+.6
SUCCESS 00
SUCCESS 1-7
SUCCESS 00+(7/00)*7+4/0
SUCCESS 5+2+5
SUCCESS 6+1
SUCCESS 3+8/3/3
SUCCESS (00)+-1+(2/0*3-6-00+2*9)
SUCCESS 3+-2+1/.4--3*5/5/6+.6
SUCCESS -00--1+-8
MAXTOKEN 3+00-1*5-.1-8*7*9+00+00*7/9/(00*4/
SUCCESS 3-.00+00+4
SUCCESS 5/6-2
SUCCESS 3/-2*2
SUCCESS 1*-8+5
SUCCESS 7--(8+6)
SUCCESS 5*8*-.8/3/0+7+4/8+.00
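To spot-check such generations independently of lark, here is a hand-rolled recursive-descent recognizer for the trimmed grammar above; the NUMBER pattern is a simplification of lark's common.NUMBER (no exponents), and whitespace handling is omitted:

```python
import re

# Tokens of the trimmed grammar: numbers like 00, .6, 3.5, and the operators.
TOKEN = re.compile(r"(\d+\.?\d*|\.\d+|[-+*/()])")

def tokenize(s):
    tokens, pos = [], 0
    while pos < len(s):
        m = TOKEN.match(s, pos)
        if not m:
            raise ValueError(f"bad character at position {pos}")
        tokens.append(m.group(1))
        pos = m.end()
    return tokens

def accepts(s):
    """True iff s is in the language of the trimmed calc_grammar."""
    try:
        toks = tokenize(s)
        i = 0

        def peek():
            return toks[i] if i < len(toks) else None

        def eat(expected=None):
            nonlocal i
            tok = peek()
            if tok is None or (expected is not None and tok != expected):
                raise ValueError(f"unexpected token {tok!r}")
            i += 1
            return tok

        def atom():
            if peek() == "-":              # "-" atom        -> neg
                eat("-"); atom()
            elif peek() == "(":            # "(" sum ")"
                eat("("); sum_(); eat(")")
            elif peek() is not None and peek() not in "+*/)":
                eat()                      # NUMBER          -> number
            else:
                raise ValueError("expected atom")

        def product():                     # atom (("*"|"/") atom)*
            atom()
            while peek() in ("*", "/"):
                eat(); atom()

        def sum_():                        # product (("+"|"-") product)*
            product()
            while peek() in ("+", "-"):
                eat(); product()

        sum_()
        return i == len(toks)
    except ValueError:
        return False

print(accepts("7--(8+6)"), accepts("(00*4/"))  # -> True False
```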

@rlouf (member) commented Dec 18, 2023

There is a merge conflict to resolve. We will need a few tests for CFGFSM, and then this should be good to merge.

@benlipkin (author)

Thanks, have resolved the merge conflict and added a few test cases. Let me know if others would be beneficial. I also outlined the previously raised edge case in a commented-out test for subsequent consideration once there is consensus on how to use the sampler to drive that decision point.

@rlouf (member) commented Dec 18, 2023

LGTM, we just need to rebase on main to remove the merge commit. It is also preferable to mark the failing test with the decorator @pytest.mark.xfail(reason="the reason") instead of commenting it out; that makes the failing line easier to track down.
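For reference, the suggested decorator usage looks like this (the reason string and test name are illustrative):

```python
import pytest

# Mark the known-failing edge-case test instead of commenting it out, so
# pytest still collects it and reports it as an expected failure (xfail).
@pytest.mark.xfail(reason="regex-done vs. continue branching not yet driven by the sampler")
def test_cfg_edge_case():
    assert False  # placeholder body for the edge case described above
```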

@rlouf (member) commented Dec 19, 2023

Thank you for adding the tests! I'm merging and cutting a new release

@rlouf rlouf merged commit dce9265 into outlines-dev:main Dec 19, 2023
5 checks passed
@EFord36 commented Dec 19, 2023

Hi - I subscribed to this because this feature is quite useful for me.

I wonder if it's also worth putting in the README? Both in the features bullets and a longer version to explain it?

There are other types of guided generation described there, so it seems a shame to miss this one (but maybe as someone interested in this feature, I'm overestimating how many others are interested).

@rlouf (member) commented Dec 19, 2023

Sure, I would be happy to review a PR if you feel up to adding a simple example in the README. Can also be added in the "features" bullet points :)

@EFord36 commented Dec 22, 2023

> Sure, I would be happy to review a PR if you feel up to adding a simple example in the README. Can also be added in the "features" bullet points :)

Just went to look at this and saw you did it yourself - thanks, appreciate it, and Outlines!

Successfully merging this pull request may close this issue: "Create a parser demo that uses the new Sequence functionality."
4 participants