
Add Grammars #2105

Closed
wants to merge 77 commits into from
Conversation

lapp0 commented Dec 14, 2023

Fixes #1229

Implement an incremental LALR / regex parser to determine the legal-next-token set.

Rendered documentation

Try it

I smoke tested with

Docker image: https://hub.docker.com/r/lapp0/vllm_grammar_branch (commit 2b2b024)

import asyncio
import json

import aiohttp
import tqdm.asyncio


async def fetch_response(session, doc, api_url_base, timeout=60):
    # get_prompt / get_grammar are the test script's own helpers
    prompt = get_prompt(doc)
    grammar = get_grammar(doc)
    headers = {"User-Agent": "Test Client"}
    pload = {
        "model": "YOUR HF MODEL URI",
        "prompt": prompt,
        "grammar": grammar,
        # optional:
        #"n": 8,
        #"use_beam_search": True,
        #"temperature": 0.0,
        #"add_generation_prompt": False,
        #"max_tokens": 4096,
        #"logprobs": True
    }

    async with session.post(f"{api_url_base}/v1/completions", headers=headers, json=pload, timeout=timeout) as response:
        return await response.json()

async def do_smoke_test(api_url_base, max_concurrent=16, timeout=2000):
    documents = json.load(open("smoke_testing_docs.json"))[:256]
    results = {}

    connector = aiohttp.TCPConnector(limit=max_concurrent)
    session_timeout = aiohttp.ClientTimeout(
        total=None,
        sock_connect=timeout,
        sock_read=timeout,
    )
    async with aiohttp.ClientSession(connector=connector, timeout=session_timeout) as session:
        tasks = [fetch_response(session, doc, api_url_base, timeout) for doc in documents]
        for doc, response in zip(documents, await tqdm.asyncio.tqdm.gather(*tasks)):
            results[doc] = response
    return results
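
If it helps, a hypothetical entry point for running the smoke test above (the URL is a placeholder):

if __name__ == "__main__":
    # assumes the vLLM OpenAI-compatible server from the image above is already running
    asyncio.run(do_smoke_test("http://localhost:8000"))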

TODO

  • InteractivePredictiveLALRParser
    • Load Lark EBNF grammars
    • Track terminals and state transitions
    • Determine token validity, whether it partially or wholly completes the terminal's pattern
  • TokenTrie
    • Efficiently retrieve candidate tokens
  • NextTokenValidator
    • Update parser state
    • Get valid token IDs given a tokenizer vocabulary and the in-progress completion's text
  • GrammarLogitProcessor
    • Implement def __call__(self, token_ids, logits), which updates the parser's state with new tokens and filters logits based on NextTokenValidator's valid token IDs (a hedged sketch follows this list)
  • Testing
    • Performance tests: see docs
  • [/] Documentation
  • Clean up tests; right now they're all in grammar.py
  • Handle EOS token
  • Finish test cases
  • IncrementalParserState
    • Refactor into immutable parser state
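
For illustration only, a minimal sketch of the __call__ contract described in the list above; the constructor argument and the validator's valid_token_ids method are assumptions, not this PR's actual API:

import torch

class GrammarLogitProcessorSketch:
    """Illustrative sketch only; names and interfaces are assumptions."""

    def __init__(self, next_token_validator):
        # assumed to expose valid_token_ids(token_ids) -> set of legal next token ids
        self.validator = next_token_validator

    def __call__(self, token_ids, logits: torch.Tensor) -> torch.Tensor:
        # Advance the parser with the tokens generated so far, then mask every
        # logit whose token cannot legally come next under the grammar.
        valid_ids = self.validator.valid_token_ids(token_ids)
        mask = torch.zeros_like(logits, dtype=torch.bool)
        mask[torch.tensor(sorted(valid_ids), dtype=torch.long)] = True
        logits[~mask] = float("-inf")
        return logits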

Ramblings from previous implementation:

Grammar Token Filter Algorithm

The GTF algorithm involves calculating the set of valid next-tokens given an incomplete sequence and a grammar.

Its core components are

  • A set of tokens, where every token is either a single character or can be composed from other tokens
  • An interactive parser which can determine the validity of incomplete sequences.

Lazy Approach

The simplest algorithm is

def get_valid_tokens(base_sequence, token_vocabulary):
    valid_tokens = set()
    for token in token_vocabulary:
        if parser.is_valid_sequence(base_sequence + token):
            valid_tokens.add(token)
    return valid_tokens

This approach has two core inefficiencies:

  • If the token "foobar" is valid, we already know "foo" is valid, yet each is checked from scratch, so redundant work is performed
  • The parser re-applies the state transitions for base_sequence once for every token

This PR's Approach

The current GTF algorithm improves on the lazy approach in two aspects:

  • token_vocabulary is a trie, allowing us to check "foo", and if invalid, we know "foobar", "foobaz", and "foobarbaz" are also invalid.
  • The parser is interactive, meaning it doesn't need to recalculate the base_sequence each time.

The current approach is a depth-first search of the token trie using a parser pre-warmed with base_sequence.

def get_valid_tokens(parser, token_trie, trie_root=""):
    valid_tokens = set()
    for token in token_trie.children(trie_root):
        if parser.is_valid_next_token(token):
            child_parser = parser.step_seq(token)
            valid_tokens.add(token)
            valid_tokens |= get_valid_tokens(child_parser, token_trie, token)
    return valid_tokens
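
For reference, a minimal token-trie sketch (hypothetical, not the PR's TokenTrie class) showing how children() supports this pruning:

from collections import defaultdict

class TokenTrie:
    """Minimal character trie over a token vocabulary (illustrative only)."""

    def __init__(self, vocab):
        self._children = defaultdict(set)
        self._tokens = set(vocab)
        for token in vocab:
            for i in range(len(token)):
                # each node is identified by the prefix string leading to it
                self._children[token[:i]].add(token[:i + 1])

    def children(self, node=""):
        """Prefixes one character longer than node; pruning a node skips its whole subtree."""
        return self._children[node]

    def is_token(self, node):
        """True if the node is a complete vocabulary token."""
        return node in self._tokens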

The main weakness of this implementation involves regular expressions: if a terminal rule is a regular expression, partial matches against it must be recomputed redundantly for every trie node.
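
As a concrete illustration of that partial-match problem, one way to test whether a prefix could still complete a terminal's pattern is partial matching from the third-party regex package (the terminal below is just an example, not one of the PR's terminals):

import regex  # third-party package with partial-match support

SIGNED_NUMBER = regex.compile(r"[+-]?\d+(\.\d+)?")  # example terminal

def could_complete(prefix):
    """True if prefix already matches, or could still grow into, the terminal."""
    return SIGNED_NUMBER.fullmatch(prefix, partial=True) is not None

print(could_complete("12."))   # True  ("12.5" would complete it)
print(could_complete("12a"))   # False (no continuation can match)

In the trie DFS above, this kind of check is repeated at every node, even though a node shares almost all of its work with its parent.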

Optimal (Future) Approach

The optimal GTF approach requires every terminal rule to be a single character, so all terminals, including regular expressions, must be decomposed.

For example the regular expression

\d{5}(-\d{4})?

Must be decomposed into

digit = "\d"
five_digits = digit, digit, digit, digit, digit;
four_digits = digit, digit, digit, digit;
optional_suffix = "-", four_digits;
zipcode = five_digits, [optional_suffix];

Additionally we can use a helper function legal_chars(character_expr) which retrieves all characters legal within a character regexp, e.g.

  • legal_chars("\d") = set(["0", "1", "2", ...])
  • legal_chars("[ae") = se["a", "e"]

With this optimization the GTF algorithm would be as follows:

def get_valid_tokens(parser, token_trie, token_trie_roots=None):
    valid_tokens = set()
    for next_terminal_expr in parser.get_next_terminals_exprs():
        parser_next_chars = legal_chars(next_terminal_expr)
        # intersect the trie's next characters with the characters the parser allows
        legal_token_prefixes = token_trie.children(token_trie_roots) & parser_next_chars
        if legal_token_prefixes:
            child_parser = parser.transition(next_terminal_expr)
            legal_token_suffixes = get_valid_tokens(child_parser, token_trie, legal_token_prefixes)
            valid_tokens |= token_trie.combine(legal_token_prefixes, legal_token_suffixes)
    return valid_tokens

This function applies a state transition only once for every transition that is legal within the token set, as opposed to the current implementation, which applies a state transition once for each unique token-trie node.

Breaking terminals down into single characters provides another advantage: we don't have to recompute a partial regular-expression match redundantly. If foo matches (foo|bar)(bazbif), we don't need to recalculate the entire regex for foobaz again. In fact, we don't evaluate regular expressions at all; we simply generate the valid character set for a given atomic character expression and intersect it with the trie's valid-token-prefix set.

Example

I use a simple Thompson's-style regex to generate the eNFA dict via automata_toolkit.

Sample code which assigns random values to logits and generates a grammar-constrained completion:

    regexp = r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))."

    sample_from_logits = lambda lgts: np.random.choice(len(lgts), p=np.exp(lgts)/np.sum(np.exp(lgts)))

    for i in range(4):

        logit_processor = TokenConstraintLogitProcessor(
            tokenizer=tokenizer,
            nfa=EpsilonNFA(nfa=regex_to_nfa.regex_to_nfa(regexp)),
        )

        token_ids = []
        while True:
            logits = logit_processor(
                token_ids=token_ids,
                logits=np.random.uniform(-10, 10, len(tokenizer.vocab))
            )
            new_token_id = sample_from_logits(logits)
            token_ids.append(new_token_id)
            if new_token_id == tokenizer.eos_token_id:
                break
        print(f"run #{i}")
        print("\ttokenid", token_ids)
        print("\ttokens:", [tokenizer.decode(tok_id, ) for tok_id in token_ids])
        print("\tresult:", tokenizer.decode(token_ids, skip_special_tokens=False))

Output:

regexp: r"(large )?(language )((models )+(inference engines ))(are )((useful)+((very )*complex))."

run #0
	tokenid [2220, 28712, 28721, 104, 305, 28708, 113, 2851, 28708, 490, 3418, 5149, 28713, 264, 267, 1001, 112, 452, 720, 49, 2]
	tokens: ['la', 'r', 'g', 'e', 'l', 'a', 'n', 'gu', 'a', 'ge', 'mo', 'del', 's', 'a', 're', 'co', 'm', 'pl', 'ex', '.', '</s>']
	result: large language models are complex.</s>
run #1
	tokenid [2220, 7879, 104, 28705, 28714, 2374, 465, 4319, 9417, 358, 17048, 597, 104, 28705, 675, 452, 720, 49, 2]
	tokens: ['la', 'rg', 'e', '', 'l', 'angu', 'age', 'inf', 'eren', 'ce', 'engines', 'ar', 'e', '', 'com', 'pl', 'ex', '.', '</s>']
	result: large language inference engines are complex.</s>
run #2
	tokenid [2220, 7879, 104, 28705, 4730, 120, 465, 968, 1190, 264, 267, 1429, 28724, 4630, 49, 2]
	tokens: ['la', 'rg', 'e', '', 'lang', 'u', 'age', 'mod', 'els', 'a', 're', 'ver', 'y', 'complex', '.', '</s>']
	result: large language models are very complex.</s>
run #3
	tokenid [16962, 543, 113, 28721, 120, 465, 4319, 9417, 28717, 104, 2536, 1303, 597, 104, 332, 28713, 797, 120, 28714, 49, 2]
	tokens: ['large', 'la', 'n', 'g', 'u', 'age', 'inf', 'eren', 'c', 'e', 'eng', 'ines', 'ar', 'e', 'u', 's', 'ef', 'u', 'l', '.', '</s>']
	result: large language inference engines are useful.</s>

Please observe that ["la", "rg", "e"] and ["large"] are both valid tokenizations within the grammar, and either may be generated.

xuy commented Dec 21, 2023

Thanks for putting this together @lapp0.
I managed to integrate it with the rest of vLLM and got legit outputs!
A minor change I made was that the return value of GrammarLogitProcessor.__call__ should be a tensor, not a list.
I made a minor change to make it work; hope it helps.

# inside GrammarLogitProcessor.__call__, after valid_token_ids has been computed
N = len(self.tokenizer.vocab)
mask = torch.zeros(N, dtype=torch.bool)
valid = torch.tensor(valid_token_ids, dtype=torch.long)
mask[valid] = True
logits[~mask] = float('-inf')  # ban every token the grammar does not allow
return logits

lapp0 (Author) commented Dec 21, 2023

Appreciate your review, fix, and interest @xuy. Will integrate that after I'm done with some bug fixes!

@brucethemoose

Does this only work with the OpenAI API at the moment? If so, could it be added to the vLLM API as well?

@l4b4r4b4b4

Works nicely so far. I noticed the preprocessing for batching is done on only one core, which significantly stalls the process. Is that due to the grammar implementation? And is there a way to fix it, either by using the GPU or more than a single core?

@brucethemoose

@lapp0 Could you post your multiprocessing branch, even if it's incomplete? I've been trying to implement it myself, but it seems I can't get it quite right.

lapp0 (Author) commented Jan 10, 2024

@brucethemoose It's pretty poorly implemented, but here you go: https://github.com/lapp0/vllm/tree/grammar-multiprocessing

I've been working on integrating some of my caching changes into https://github.com/outlines-dev/outlines which already has regex-based guidance for vLLM.

viktor-ferenczi (Contributor) commented Jan 13, 2024

Tested the grammar support from your branch.

Additional changes I made:

  • Rebased it on latest vLLM main
  • Added the grammar parameter to ChatCompletionRequest as well and handled it the same way in the v1/chat/completions request handler, since I use an instruct fine-tuned model via that path.

Model: TheBloke/deepseek-coder-33B-instruct-AWQ
System: "You are a helpful AI assistant. You give concise answers. If you do not know something, then say so."
User: "Write down the first 10 prime numbers as a comma separated list on a single line. Do not write anything else."

Without the grammar the model gives this response:

"2, 3, 5, 7, 11, 13, 17, 19, 23, 29"

So in the grammar I intentionally denied any use of white-space, so the expected output must be:

"2,3,5,7,11,13,17,19,23,29"

Grammar:

?start: SIGNED_NUMBER ( "," SIGNED_NUMBER )*
%import common.SIGNED_NUMBER

While the output conforms to the grammar, it fails to produce the two-digit prime numbers:

"2,3,5,7,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,"

It may be that the grammar code somehow prevents it from writing 11 there, i.e. it cannot write a number with multiple digits.

Changed the grammar to be stricter and simpler:

?start: DIGIT+ ( "," DIGIT+ )*
%import common.DIGIT

With this grammar the model produces the primes but cannot stop, so there is a problem in the code preventing it from generating the EOS token. Generating the stop token should be allowed wherever it is consistent with the grammar.

"2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,"

Grammar support would be really awesome to have for my use case. I had actually started implementing Lark support and had already worked out this PR's algorithm in my own tests outside vLLM when I found your PR. It is really great that you have made this much progress already, so there is a chance to have grammar support soon.

Even if we cannot use regex in our grammars, having any kind of grammar support would still be a huge win. Grammar support would also allow for reliable function calling, which is also in the works (#2360); that first PR refactors the spaghetti code in the OpenAI-compatible server.

In llama.cpp there is a grammar format named GBNF, an EBNF variant. It already works, and its integration with the sampling code can give us some ideas on how to optimize this in vLLM.

viktor-ferenczi (Contributor) commented Jan 13, 2024

When I change the grammar to allow whitespace, then it can generate the primes properly:

"2, 3, 5, 7, 11, 13, 17, 19, 23, 29"

Grammar:

?start: _WS? DIGIT+ ( _WS? "," _WS? DIGIT+ )* _WS?
%import common.DIGIT
%import common.WS -> _WS

It worked without code changes and could stop. I do not know why it cannot stop when no white-space is allowed. The model writes spaces only after the commas, and no newline is generated at the end of the completion.

Narrowed it down to this grammar. It works, but has to produce a newline at the end, so it can stop:

?start: DIGIT+ ( "," DIGIT+ )* _WS?
%import common.DIGIT
%import common.WS -> _WS

Output as a Python string:

"2,3,5,7,11,13,17,19,23,29\n"

So the actual bug is that the grammar does not let the LLM generate a stop token if the grammar does not allow white-space at the very end. At least that seems to be the case based on the few tests I've done. I'm not sure whether it is a limitation of this specific LLM due to how it was trained (it must write a newline before EOS) or a bug in the sampler integration of the grammar (it does not allow EOS in that case).

The CPU overhead of the grammar is indeed horrible. Speed is down from 32T/s to 8.4T/s with the above very simple grammar.

lapp0 (Author) commented Jan 13, 2024

@viktor-ferenczi

The parser doesn't handle ambiguous terminals well. Could you try converting them to rules? Something along the lines of:

    signed_number: ["+"|"-"] number
    number: float | int
    float: int exp | decimal exp?
    decimal: int "." int? | "." int
    exp: ("e"|"E") signed_int
    signed_int: ["+"|"-"] int
    int: DIGIT+
    DIGIT: "0".."9"

And yes, the speed is bad. Outlines addresses this by precompiling the regex FSM and using Numba. I'm leaning heavily towards thinking vLLM should be a strong, simple inference engine and outlines should be a wrapper on top for grammars.

Outlines' vLLM CFG implementation was merged yesterday: dottxt-ai/outlines#517

viktor-ferenczi (Contributor) commented Jan 13, 2024

The grammar you suggested crashes vLLM with this exception:

TypeError: UnexpectedToken.__init__() missing 2 required positional arguments: 'token' and 'expected'

The traceback is useless because of the use of Ray (2 GPUs).

Performance: I was running vLLM with cProfile and executed the completion some 50 times in about 2 minutes. The grammar was responsible for only ~550 ms of CPU runtime, so I don't see from the profiling data where the experienced slowdown comes from. The grammar's CPU load is 60-70% of a core, so it does not seem to be CPU bound there. I guess the load does not show up in the Python profiler but is introduced by the use of tensors (GPU RAM access?) or similar. I don't know enough Torch and CUDA yet to tell exactly.

I'm leaning heavily towards thinking vLLM should be a strong, simple inference engine and outlines should be a wrapper on top for grammars.

Where would you put the grammar support? If we keep it inside vLLM, then it can be used via the REST APIs. That's what I prefer, at least for my use case. It allows for hosting the LLM separately from the application and better scalability, all without having to write a custom server for each application or forcing the application to run the LLM directly in-process.

@viktor-ferenczi (Contributor)

The exception due to the grammar:

...
  File "/home/viktor/dep/vllm-contrib/vllm/model_executor/layers/sampler.py", line 155, in _apply_logits_processors
    logits_row = logits_processor(token_ids, logits_row)
  File "/home/viktor/dep/vllm-contrib/vllm/grammar.py", line 472, in __call__
    return ray.get(result_id)
... ray ...
ray.exceptions.RaySystemError: System error: Failed to unpickle serialized exception
... ray ...
TypeError: UnexpectedToken.__init__() missing 2 required positional arguments: 'token' and 'expected'

So Ray cannot relay Lark's UnexpectedToken error. It is also apparently not handled properly and is turned into a Bad Request error by the API server.

lapp0 (Author) commented Jan 13, 2024

Performance: I was running vLLM with cProfile and executed the completion some 50 times in about 2 minutes. The grammar was responsible for only ~550 ms of CPU runtime, so I don't see from the profiling data where the experienced slowdown comes from. The grammar's CPU load is 60-70% of a core, so it does not seem to be CPU bound there. I guess the load does not show up in the Python profiler but is introduced by the use of tensors (GPU RAM access?) or similar. I don't know enough Torch and CUDA yet to tell exactly.

Are you using multiple GPUs? I'm seeing a substantial slowdown when passing the tensors to the logits processor ray actor.

Where would you put the grammar support? If we keep it inside vLLM, then it can be used via the REST APIs. That's what I prefer, at least for my use case. It allows for hosting the LLM separately from the application and better scalability, all without having to write a custom server for each application or forcing the application to run the LLM directly in-process.

https://outlines-dev.github.io/outlines/reference/vllm/

@viktor-ferenczi (Contributor)

@lapp0 Tried the outlines.serve.serve way. The JSON schema and regex modes work, but the grammar (CFG) mode does not; see the outlines bug report on this. Also, that solution does not work with tensor parallel at all (see the bug ticket). It looks like everything is implemented, just not reliable yet.

@jqueguiner

There are already actively maintained libraries for guided generation that can integrate with vLLM, like Outlines. I would be wary of introducing code that is tangentially related to this library and will require a substantial amount of maintenance when the problem can be solved by an import. Why not contribute this code to those libraries and import them here instead?

http://outlines-dev.github.io/outlines/reference/vllm/

@viktor-ferenczi (Contributor)

@jqueguiner The custom logits processors need some more information to be passed to them to avoid having to patch vLLM the hard way. The primary example is a way to identify the sequence (seq_id), and maybe more. Please look into the implementation of outlines.serve.serve, specifically _patched_apply_logits_processors.

rlouf commented Jan 15, 2024

The seq_id can probably be replaced with a hash of the token ids if that's really the blocker. But that's beside the point: even if we needed to pass seq_id, I agree with @jqueguiner that it's an easier change for the vLLM team and requires substantially less maintenance over time.

lapp0 (Author) commented Jan 15, 2024

@jqueguiner The custom logits processors need some more information to be passed to them to avoid having to patch vLLM the hard way. The primary example is a way to identify the sequence (seq_id), and maybe more. Please look into the implementation of outlines.serve.serve, specifically _patched_apply_logits_processors.

This experimental change where the state is cached by the hash of the prior token ids is working for me so far:

outlines-dev/outlines@8b1ff9a#diff-f65ffb5f52b2e358c713ccb8f32a700769426c6c8b655f689e3cdccae07d22ac
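
A minimal sketch of that idea (hypothetical names, not the outlines code): key each parser state by the hash of the preceding token ids so a sequence can be resumed without a seq_id:

class CachedParserStates:
    """Illustrative only: resume parser state from the hash of the token-id prefix."""

    def __init__(self, initial_state, step):
        self._initial_state = initial_state  # starting parser state
        self._step = step                    # function (state, token_id) -> new state
        self._cache = {}                     # hash of prefix -> parser state

    def state_for(self, token_ids):
        key = hash(tuple(token_ids))
        if key not in self._cache:
            if token_ids:
                # reuse the cached state of the prefix and advance by one token
                prev = self.state_for(token_ids[:-1])
                self._cache[key] = self._step(prev, token_ids[-1])
            else:
                self._cache[key] = self._initial_state
        return self._cache[key]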

@viktor-ferenczi (Contributor)

A hash on preceding tokens is even better than seq_id, because it would allow for further optimization should prompts be repeated.

@simon-mo (Collaborator)

Hi everyone, thank you so much for the very active discussion here. As a vLLM maintainer, I want to express my sincere thanks for your enthusiasm. vLLM as a project is focused on optimizing LLM inference and providing a fully compatible OpenAI API; constrained decoding is not our strong suit, and we don't have the expertise to maintain it.

@lapp0, would you consider closing this PR and merging into outlines instead? I think you mentioned that here. I would very much like to use outlines directly in vLLM after #2488 is merged (or sooner; adding it to the completion API is another option).

@lapp0 and @viktor-ferenczi, please let us know what interface and scheduling change on the vLLM side is needed to better support this functionality.

lapp0 (Author) commented Jan 19, 2024

Sure @simon-mo, I will follow up with you on any changes to vLLM that are necessary. Thanks for your enthusiastic support!

Closing in favor of outlines. A few changes are still necessary in outlines to consider guidance ready for vLLM:
