Add integration for transformers via logits processors
saattrupdan authored and rlouf committed Mar 12, 2024
1 parent 11143df commit 5d97ee1
Showing 16 changed files with 763 additions and 319 deletions.
25 changes: 25 additions & 0 deletions examples/transformers_integration.py
@@ -0,0 +1,25 @@
+"""Example of integrating `outlines` with `transformers`."""
+
+from pydantic import BaseModel
+from transformers import pipeline
+
+from outlines.integrations.transformers import JSONPrefixAllowedTokens
+
+
+class Person(BaseModel):
+    first_name: str
+    surname: str
+
+
+pipe = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1")
+prefix_allowed_tokens_fn = JSONPrefixAllowedTokens(
+    schema=Person, tokenizer_or_pipe=pipe, whitespace_pattern=r" ?"
+)
+results = pipe(
+    ["He is Tom Jones", "She saw Linda Smith"],
+    return_full_text=False,
+    do_sample=False,
+    max_new_tokens=50,
+    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+)
+print(results)
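
A note on consuming this output: because the prefix-allowed-tokens function constrains decoding to the schema, each generated string should parse as a `Person`. The sketch below is not part of the commit; it assumes Pydantic v2 (`model_validate_json`) and the usual list-of-lists shape of a batched `text-generation` pipeline.

# Post-processing sketch (assumes Pydantic v2 and a batched pipeline output,
# i.e. one list of candidate dicts per input prompt).
for prompt_results in results:
    # With `return_full_text=False`, `generated_text` holds only the completion.
    person = Person.model_validate_json(prompt_results[0]["generated_text"])
    print(person.first_name, person.surname)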
20 changes: 12 additions & 8 deletions examples/vllm_integration.py
@@ -1,20 +1,24 @@
 """Example of integrating `outlines` with `vllm`."""

 import vllm
 from pydantic import BaseModel

-from outlines.serve.vllm import JSONLogitsProcessor
+from outlines.integrations.vllm import JSONLogitsProcessor


-class User(BaseModel):
-    id: int
-    name: str
+class Person(BaseModel):
+    first_name: str
+    surname: str


-llm = vllm.LLM(model="openai-community/gpt2")
-logits_processor = JSONLogitsProcessor(schema=User, llm=llm)
+llm = vllm.LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=512)
+logits_processor = JSONLogitsProcessor(schema=Person, llm=llm, whitespace_pattern=r" ?")
 result = llm.generate(
-    ["A prompt", "Another prompt"],
+    ["He is Tom Jones", "She saw Linda Smith"],
     sampling_params=vllm.SamplingParams(
-        max_tokens=100, logits_processors=[logits_processor]
+        temperature=0.0,
+        max_tokens=50,
+        logits_processors=[logits_processor],
     ),
 )
 print(result)
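
As in the transformers example, the constrained text can be validated against the schema. A sketch, assuming vLLM's usual return shape of one RequestOutput per prompt with the generated text under `outputs[0].text`:

# Validation sketch: pull the constrained JSON out of each RequestOutput.
for request_output in result:
    person = Person.model_validate_json(request_output.outputs[0].text)
    print(person)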
7 changes: 6 additions & 1 deletion outlines/fsm/guide.py
@@ -62,6 +62,9 @@ def get_next_state(self, state: int, token_id: int) -> int:
     def is_final_state(self, state: int) -> bool:
         ...

+    def copy(self) -> "Guide":
+        ...
+

 class StopAtEOSGuide(Guide):
     """Guide to generate tokens until the EOS token has been generated."""

@@ -189,7 +192,9 @@ def get_next_state(self, state: int, token_id: int) -> int:
         """
         if token_id == self.eos_token_id:
             return -1
-        elif state in self.final_states:
+        elif (
+            state in self.final_states
+        ):  # Necessary because we keep generating EOS tokens when finished
             return state

         last_token_to_end_state = self.states_to_token_maps[state]
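
Why the protocol gains a `copy` method: engines that run several sequences concurrently need one guide per sequence, so every guide must say how to duplicate itself. The sketch below is hypothetical (the CountingGuide name and its counter are invented for illustration): a guide with mutable per-sequence state returns a deep copy, while a guide whose transition tables are immutable could simply return `self`.

import copy


class CountingGuide:
    """Hypothetical guide with mutable per-sequence state."""

    def __init__(self) -> None:
        self.tokens_seen = 0

    def get_next_state(self, state: int, token_id: int) -> int:
        self.tokens_seen += 1
        return state

    def is_final_state(self, state: int) -> bool:
        return state == -1

    def copy(self) -> "CountingGuide":
        # Each concurrent sequence needs its own counter, hence a deep copy;
        # a stateless guide could return `self` instead.
        return copy.deepcopy(self)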
43 changes: 33 additions & 10 deletions outlines/fsm/types.py
@@ -1,5 +1,5 @@
 import datetime
-from typing import Any, Callable, Tuple
+from typing import Protocol, Tuple, Type, Union

 INTEGER = r"[+-]?(0|[1-9][0-9]*)"
 BOOLEAN = "(True|False)"

@@ -9,26 +9,49 @@
 DATETIME = rf"({DATE})(\s)({TIME})"


-def python_types_to_regex(python_type: Any) -> Tuple[str, Callable[[str], Any]]:
+class FormatFunction(Protocol):
+    def __call__(
+        self, sequence: str
+    ) -> Union[int, float, bool, datetime.date, datetime.time, datetime.datetime]:
+        ...
+
+
+def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]:
     if python_type == float:
-        float_format_fn = lambda x: float(x)
+
+        def float_format_fn(sequence: str) -> float:
+            return float(sequence)
+
         return FLOAT, float_format_fn
     elif python_type == int:
-        int_format_fn = lambda x: int(x)
+
+        def int_format_fn(sequence: str) -> int:
+            return int(sequence)
+
         return INTEGER, int_format_fn
     elif python_type == bool:
-        bool_format_fn = lambda x: bool(x)
+
+        def bool_format_fn(sequence: str) -> bool:
+            return bool(sequence)
+
         return BOOLEAN, bool_format_fn
     elif python_type == datetime.date:
-        date_format_fn = lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date()
+
+        def date_format_fn(sequence: str) -> datetime.date:
+            return datetime.datetime.strptime(sequence, "%Y-%m-%d").date()
+
         return DATE, date_format_fn
     elif python_type == datetime.time:
-        time_format_fn = lambda s: datetime.datetime.strptime(s, "%H:%M:%S").time()
+
+        def time_format_fn(sequence: str) -> datetime.time:
+            return datetime.datetime.strptime(sequence, "%H:%M:%S").time()
+
         return TIME, time_format_fn
     elif python_type == datetime.datetime:
-        datetime_format_fn = lambda s: datetime.datetime.strptime(
-            s, "%Y-%m-%d %H:%M:%S"
-        )
+
+        def datetime_format_fn(sequence: str) -> datetime.datetime:
+            return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S")
+
         return DATETIME, datetime_format_fn
     else:
         raise NotImplementedError(
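
A usage sketch of the rewritten function, grounded in the definitions above: the returned regex constrains generation, and the format function casts the raw string afterwards.

import datetime

from outlines.fsm.types import python_types_to_regex

regex_str, format_fn = python_types_to_regex(int)
print(regex_str)        # [+-]?(0|[1-9][0-9]*)
print(format_fn("42"))  # 42, as a Python int

_, date_fn = python_types_to_regex(datetime.date)
print(date_fn("2024-03-12"))  # datetime.date(2024, 3, 12)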
18 changes: 11 additions & 7 deletions outlines/generate/api.py
@@ -1,9 +1,14 @@
+import datetime
 from typing import Iterator, List, Optional, Union

 import torch

 from outlines.generate.generator import sequence_generator

+FormattedOutput = Union[
+    str, int, float, bool, datetime.date, datetime.time, datetime.datetime
+]
+

 class SequenceGenerator:
     def __init__(

@@ -100,7 +105,7 @@ def strip_stop_sequences(

         return sequence

-    def format_sequence(self, sequence: str) -> str:
+    def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.

         This method is for instance overridden when generating JSON to either

@@ -124,7 +129,7 @@ def __call__(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional[torch.Generator] = None,
-    ) -> Union[str, List[str], List[List[str]]]:
+    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
        """Generate the full text sequence.

         Since `SequenceGenerator.stream` calls the tokenizer at every step this

@@ -148,8 +153,7 @@
         Returns
         -------
-        A string or list of strings that contain the generated text.
+        The generation(s), potentially cast to another type.
         """

         if isinstance(prompts, str):

@@ -222,7 +226,7 @@ def __call__(
         formatted = [self.format_sequence(sequence) for sequence in stripped]

         # We reshape the output to (batch_size, sample_size)
-        output = []
+        output: List[List[FormattedOutput]] = list()
         for i in range(batch_size):
             output.append(formatted[i : i + num_samples])

@@ -242,7 +246,7 @@ def stream(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional[torch.Generator] = None,
-    ) -> Iterator[Union[List[str], List[List[str]], str]]:
+    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
         """Generate the text sequence one token at a time.

         Since `Tokenizer.decode` strips the whitespaces from the tokens we have no

@@ -352,7 +356,7 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
             ]

             # We reshape the output to (batch_size, sample_size)
-            output = []
+            output: List[List[str]] = list()
             for i in range(batch_size):
                 output.append(next_tokens[i : i + num_samples])
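
The widened FormattedOutput annotation reflects that subclasses override `format_sequence` to cast the raw text; the docstring above notes this is how JSON generation works. A hypothetical illustration (the IntegerGenerator name is invented):

from outlines.generate.api import SequenceGenerator


class IntegerGenerator(SequenceGenerator):
    """Hypothetical subclass: __call__ returns int rather than str."""

    def format_sequence(self, sequence: str) -> int:
        return int(sequence)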
11 changes: 4 additions & 7 deletions outlines/generate/regex.py
@@ -2,12 +2,9 @@

 from outlines.fsm.guide import RegexGuide
 from outlines.generate.api import SequenceGenerator
+from outlines.integrations.llamacpp import RegexLogitsProcessor
 from outlines.models import OpenAI
-from outlines.models.llamacpp import (
-    LlamaCpp,
-    LlamaSequenceGenerator,
-    RegexLogitsProcessor,
-)
+from outlines.models.llamacpp import LlamaCpp, LlamaSequenceGenerator
 from outlines.samplers import Sampler, multinomial

@@ -52,8 +49,8 @@ def regex_llamacpp(
             + "than the multinomial sampler."
         )

-    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
-    generator = LlamaSequenceGenerator(logits_processor, model)
+    logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
+    generator = LlamaSequenceGenerator(logits_processor=logits_processor, model=model)

     return generator
1 change: 1 addition & 0 deletions outlines/integrations/__init__.py
@@ -0,0 +1 @@
+"""Utility functions and classes used to integrate `outlines` into other packages."""