Add CFG integration to VLLM #788

Open: wants to merge 4 commits into main
.github/workflows/tests.yml: 3 additions, 0 deletions

@@ -6,6 +6,9 @@ on:
   push:
     branches: [main]
 
+env:
+  VLLM_TARGET_DEVICE: cpu
+
 jobs:
   style:
     name: Check the code style
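For context: `VLLM_TARGET_DEVICE` is read by vLLM's build system, so it only takes effect if the workflow installs vLLM from source; a CPU build is what lets the vLLM integration tests run on GitHub's GPU-less runners. A hedged sketch of the kind of install step that would consume it (this step is an assumption, not part of the diff):

    # Hypothetical workflow step; VLLM_TARGET_DEVICE=cpu (set job-wide above)
    # makes vLLM's setup.py produce a CPU-only build.
    - name: Install vLLM from source (CPU backend)
      run: |
        git clone https://github.com/vllm-project/vllm
        pip install ./vllm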
outlines/generate/cfg.py: 5 additions, 3 deletions

@@ -2,6 +2,7 @@
 
 from outlines.fsm.guide import CFGGuide
 from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter
+from outlines.integrations.vllm import CFGLogitsProcessor
 from outlines.models import OpenAI
 from outlines.models.llamacpp import LlamaCpp
 from outlines.models.vllm import VLLM
@@ -39,9 +40,10 @@ def cfg_vllm(
     cfg_str: str,
     sampler: Sampler = multinomial(),
 ):
-    raise NotImplementedError(
-        "The CFG Logits processor is not available for the vLLM integration."
-    )
+    logits_processor = CFGLogitsProcessor(cfg_str, model.model)
+    generator = SequenceGeneratorAdapter(model, logits_processor, sampler)
+
+    return generator
 
 
 @cfg.register(LlamaCpp)
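With this change, `cfg_vllm` mirrors the existing llama.cpp path. A minimal sketch of the resulting user-facing flow, assuming a vLLM-compatible model (the model name is illustrative; `grammars.arithmetic` is the grammar used by the test below):

    import outlines.generate as generate
    import outlines.grammars as grammars
    import outlines.models as models

    # Load a model through outlines' vLLM wrapper (model name is illustrative).
    model = models.vllm("mistralai/Mistral-7B-v0.1")

    # generate.cfg now returns a working generator for vLLM models instead of
    # raising NotImplementedError.
    generator = generate.cfg(model, grammars.arithmetic)
    result = generator("Write an arithmetic expression: ")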
outlines/integrations/vllm.py: 79 additions, 1 deletion

@@ -32,7 +32,7 @@
 import torch
 from pydantic import BaseModel
 
-from outlines.fsm.guide import RegexGuide
+from outlines.fsm.guide import CFGGuide, RegexGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str
 
@@ -149,3 +149,81 @@ def __init__(
         schema_str = convert_json_schema_to_str(json_schema=schema)
         regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
         super().__init__(regex_string=regex_string, llm=llm)
+
+
+class CFGLogitsProcessor:
+    """Bias vLLM generation based on a context-free grammar.
+
+    Attributes
+    ----------
+    fsm
+        The finite state machine which is used to bias the logits.
+    """
+
+    def __init__(self, cfg_string: str, llm: "LLM"):
+        """Compile the FSM that drives the CFG-structured generation.
+
+        Parameters
+        ----------
+        cfg_string
+            A context-free grammar in the EBNF format.
+        llm
+            The vLLM model.
+
+        Raises
+        ------
+        ValueError
+            If the provided LLM instance in `CFGLogitsProcessor` has neither a
+            `tokenizer` attribute nor a `get_tokenizer` method.
+        """
+        if hasattr(llm, "get_tokenizer"):
+            tokenizer = llm.get_tokenizer()
+        elif hasattr(llm, "tokenizer"):
+            if hasattr(llm.tokenizer, "tokenizer"):
+                tokenizer = llm.tokenizer.tokenizer
+            else:
+                tokenizer = llm.tokenizer
+        else:
+            raise ValueError(
+                "The provided LLM instance in `CFGLogitsProcessor` has neither a "
+                "`tokenizer` attribute nor a `get_tokenizer` method."
+            )
+        tokenizer = adapt_tokenizer(tokenizer=tokenizer)
+        self.fsm = CFGGuide(cfg_string, tokenizer)
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+
+    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token.
+
+        Parameters
+        ----------
+        input_ids
+            The tokens of the current sentence.
+        scores
+            The logits of the current sentence.
+
+        Returns
+        -------
+        torch.Tensor
+            The biased logits.
+        """
+        seq_id = hash(tuple(input_ids))
+
+        # An empty input_ids means we are at the first token of the sequence,
+        # where the defaultdict already maps the sequence to the initial FSM
+        # state. Otherwise, advance the FSM using the last sampled token and
+        # the state stored under the hash of the preceding prefix.
+        if len(input_ids) > 0:
+            last_token = input_ids[-1]
+            last_seq_id = hash(tuple(input_ids[:-1]))
+            self._fsm_state[seq_id] = self.fsm.get_next_state(
+                state=self._fsm_state[last_seq_id], token_id=last_token
+            )
+
+        allowed_tokens = self.fsm.get_next_instruction(
+            state=self._fsm_state[seq_id]
+        ).tokens
+
+        # Keep only the tokens the grammar currently allows; everything else
+        # is masked to -inf before sampling.
+        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
+        mask[allowed_tokens] = 0
+        biased_scores = scores + mask
+
+        return biased_scores
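For users driving vLLM directly rather than through outlines' model wrapper, the new processor should slot into `SamplingParams.logits_processors` the same way the existing `RegexLogitsProcessor` does; a hedged sketch (model name illustrative):

    from vllm import LLM, SamplingParams

    from outlines import grammars
    from outlines.integrations.vllm import CFGLogitsProcessor

    llm = LLM("mistralai/Mistral-7B-v0.1")  # illustrative model
    logits_processor = CFGLogitsProcessor(cfg_string=grammars.arithmetic, llm=llm)

    # vLLM invokes the processor with (token_ids, logits) at each decoding step.
    params = SamplingParams(max_tokens=64, logits_processors=[logits_processor])
    outputs = llm.generate("2 + 2 =", params)
    print(outputs[0].outputs[0].text)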
tests/generate/test_integration_vllm.py: 5 additions, 7 deletions

@@ -2,7 +2,8 @@
 import re
 
 import pytest
-import torch
+
+# import torch
 from pydantic import BaseModel, constr
 from vllm.sampling_params import SamplingParams
 
@@ -11,9 +12,9 @@
 import outlines.models as models
 import outlines.samplers as samplers
 
-pytestmark = pytest.mark.skipif(
-    not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
-)
+# pytestmark = pytest.mark.skipif(
+#     not torch.cuda.is_available(), reason="vLLM models can only be run on GPU."
+# )
 
 
 @pytest.fixture(scope="module")
@@ -230,9 +231,6 @@ def test_vllm_json_schema(model):
     assert isinstance(result["bar"], str)
 
 
-@pytest.mark.xfail(
-    reason="The CFG logits processor for vLLM has not been implemented yet."
-)
 def test_vllm_cfg(model):
     prompt = "<|im_start|>user\nOutput a short and valid JSON object with two keys.<|im_end|>\n><|im_start|>assistant\n"
     result = generate.cfg(model, grammars.arithmetic)(prompt, seed=11)