
Add vLLM integration for the generation API #163

Closed
rlouf opened this issue Jun 25, 2023 · 6 comments
Labels
enhancement, text (Linked to text generation)

Comments

@rlouf
Member

rlouf commented Jun 25, 2023

vLLM allows for substantially faster inference via smart management of the KV-cache. The library offers seamless integration with some of the most popular HuggingFace models. We suggested in #150 to re-use some of the ideas in this paper/library for Outlines.

In a first iteration we can dispatch the sequence generation functions introduced in #139 to vLLM's user-facing APIs, although this would be limited to the simplest generation methods. Longer term, we should look into vLLM's internals and see if we can make Outlines compatible.
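Roughly, the vLLM entry points we would dispatch to look like this (the wrapper function below is only an illustration, not a proposed Outlines API):

# Illustration of vLLM's user-facing API (LLM + SamplingParams); the wrapper
# function itself is hypothetical.
from vllm import LLM, SamplingParams

def vllm_generate(model_name: str, prompts: list[str], max_tokens: int = 64) -> list[str]:
    llm = LLM(model=model_name)
    params = SamplingParams(max_tokens=max_tokens)
    outputs = llm.generate(prompts, params)
    return [output.outputs[0].text for output in outputs]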

rlouf added the text (Linked to text generation) and enhancement labels on Jun 25, 2023
@louisoutin

Agree, it would be nice to see an integration of vLLM into Outlines.

@amir-in-a-cynch

Is there any progress on this? I'm using vLLM with Outlines. I'd be willing to help out / try to put out a PR if there's interest in a community contribution here.

@rlouf
Member Author

rlouf commented Nov 8, 2023

There's always interest in community contributions :) However, this requires substantial changes to Outlines' codebase with far-reaching design implications. We need a bit more time. After that design change, the integration should be very straightforward.

@amir-in-a-cynch

amir-in-a-cynch commented Nov 8, 2023

There's always interest in community contributions :) However, this requires substantial changes to Outlines' codebase with far-reaching design implications. We need a bit more time. After that design change, the integration should be very straightforward.

Oh absolutely. Not trying to nag. OSS work is voluntary, unpaid, and time-consuming. Just showing interest, and willing to help out if it would accelerate things. Thanks for the update and the work you do!

@simon-mo

Hi @rlouf, Outlines is really awesome!

I'm from the vLLM team and I'm quite excited about Outlines' approach. We recently added the logits_processors API and I tried to integrate vLLM with Outlines. Here's a hacky version that I got working:

import outlines.models as models
from outlines.text.generate.regex import Regex

from vllm import LLM, SamplingParams
import torch

prompts = [
    "What is the IP Address of Google",
]

# We are not using this model for actual inference, but it seems to be required by the Regex class.
model = models.transformers("gpt2-medium", device="cuda")

regex_processor = Regex(
    model,
    regex_string=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    max_tokens=16,
)

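# vLLM calls each logits processor with the token ids generated so far for one
# sequence plus the logits for the next token; Outlines' create_proposal expects
# a batch of token id sequences, so we wrap the ids in a batch of size one.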
def create_proposal(token_ids, logits):
    token_ids = torch.Tensor([token_ids]).long().to("cuda")
    return regex_processor.create_proposal(token_ids, logits)

sampling_params = SamplingParams(
    logits_processors=[create_proposal],
    max_tokens=16,
)

# Create an LLM in vLLM.
llm = LLM(model="gpt2-medium")

for _ in range(10):
    outputs = llm.generate(prompts=prompts, sampling_params=sampling_params, use_tqdm=False)
    regex_processor.last_fsm_states.clear()  # We have to do this because we are sharing an FSM.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

However, there are a few issues:

  • The compilation time is about 10s, which is pretty long if it is used in an online setting and supplied by the users. (I posted an issue here: Tokenizer Aware FSM Construction #383)
  • There is no "processor-only" API. Outlines currently "wraps" transformers. From vLLM's point of view, we want to continue to expose logits_processors in SamplingParams to users. It would be great if users could use both libraries without one wrapping the other (see the sketch after this list).
    • As in the code above, the model has to be instantiated twice because the Regex class relies on a model.
    • Does Refactor the sequence generation #366 address this? If so, what would the API be like?
  • Sharing the Regex processor across requests is also impossible here, due to the statefulness of the FSM.
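To make the "processor-only" point concrete, the shape I have in mind is roughly the following. The RegexLogitsProcessor class is a placeholder written for illustration, not an existing Outlines API:

from transformers import AutoTokenizer

class RegexLogitsProcessor:
    """Hypothetical processor-only interface: built from a tokenizer and a regex,
    no wrapped model, and stateless so it can be shared across requests."""

    def __init__(self, regex_string, tokenizer):
        self.regex_string = regex_string
        self.tokenizer = tokenizer
        # The expensive token-level FSM compilation (~10s today) would happen once, here.

    def __call__(self, token_ids, logits):
        # Stub: a real implementation would derive the FSM state from token_ids
        # and mask the tokens the FSM disallows before returning the logits.
        return logits

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
processor = RegexLogitsProcessor(
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    tokenizer,
)
# `processor` could then be passed directly to SamplingParams(logits_processors=[processor], ...)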

@rlouf
Member Author

rlouf commented Nov 21, 2023

Hi @rlouf, Outlines is really awesome!

Thank you, so is vLLM!

Compilation might be an issue when calling the model as a one-off. When deployed in a service this hardly matters, since it only needs to happen once. We'll update the examples to show that compilation and inference can be separated. In the following, compilation only happens when the generator is created:

import outlines

model = outlines.models.transformers("gpt2")
generator = outlines.text.generate.regex(model, "[a-zA-Z]+")

result = generator("prompt")
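The same generator can then be called repeatedly; the FSM is only compiled once, when the generator is created:

# Both calls reuse the FSM compiled when `generator` was built above.
print(generator("first prompt"))
print(generator("second prompt"))
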
  • There is no "processor-only" API. Outlines currently "wraps" transformers. From vLLM's point of view, we want to continue to expose logits_processors in SamplingParams to users. It would be great if users could use both libraries without one wrapping the other.

I agree, this is partly what motivated #366. Everything is coupled and it's not ideal.

  • Does Refactor the sequence generation #366 address this? If so, what would the API be like?
  • Sharing the Regex processor across requests is also impossible here, due to the statefulness of the FSM.

Yes, the refactor in #366 addresses this and removes the statefulness. Here's the design that is being implemented:


The Generator holds the state. When it calls the finite-state machine with a state id, it gets the logits mask back. Once it has generated the new token, it queries the FSM for the next state id by passing the new token id and the current state. Does this make sense?

It should then be much easier to integrate Outlines into vLLM: only the FSM part needs to be passed to the logits processor.
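Sketched as code (the names below are illustrative, not the final API), the FSM interface a logits processor would need is roughly:

import torch

class FSM:
    """Stateless: all per-sequence state is the integer `state` that the generator holds."""

    def allowed_token_ids(self, state: int) -> list[int]:
        """Token ids that may be generated from `state`; used to build the logits mask."""
        raise NotImplementedError

    def next_state(self, state: int, token_id: int) -> int:
        """State reached after generating `token_id` from `state`."""
        raise NotImplementedError

def mask_logits(fsm: FSM, state: int, logits: torch.Tensor) -> torch.Tensor:
    """What a vLLM logits processor would do with the FSM: keep only the allowed tokens."""
    mask = torch.full_like(logits, float("-inf"))
    mask[fsm.allowed_token_ids(state)] = 0
    return logits + mask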

This issue was closed.