Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-apply chat template in SequenceGenerator and SequenceGeneratorAdapter, if available #1019

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

Expand All @@ -20,13 +21,15 @@ def __init__(
model,
sampler,
device,
apply_chat_template: bool = True,
):
self.fsm = fsm
self.model = model
self.sampler = sampler
self.tokenizer = model.tokenizer
self.device = device
self.num_samples = sampler.samples
self.apply_chat_template = apply_chat_template

def get_generated_token_ids(
self,
Expand Down Expand Up @@ -132,6 +135,7 @@ def __call__(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
apply_chat_template: Optional[bool] = None,
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
"""Generate the full text sequence.

Expand All @@ -153,16 +157,25 @@ def __call__(
rng
The random number generator. Defaults to a non-seeded `torch.Generator`
instance.
apply_chat_template
Whether to apply the chat template to the prompts. Defaults to the value
set at init. Only applies to `TransformerTokenizer` for now.

Returns
-------
The generation(s), potentially cast to another type.
"""
if apply_chat_template is None:
apply_chat_template = self.apply_chat_template

import torch

if isinstance(prompts, str):
prompts = [prompts]

if apply_chat_template:
apply_chat_template_util(self.model, prompts)

if isinstance(stop_at, str):
stop_at = [stop_at]

Expand Down Expand Up @@ -250,6 +263,7 @@ def stream(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional["torch.Generator"] = None,
apply_chat_template: Optional[bool] = None,
) -> Iterator[Union[List[str], str, List[List[str]]]]:
"""Generate the text sequence one token at a time.

Expand All @@ -270,17 +284,26 @@ def stream(
rng
The random number generator. Defaults to a non-seeded `torch.Generator`
instance.
apply_chat_template
Whether to apply the chat template to the prompts. Defaults to the value
set at init. Only applies to `TransformerTokenizer` for now.

Returns
-------
A string or list of strings that contain the generated text.

"""
if apply_chat_template is None:
apply_chat_template = self.apply_chat_template

import torch

if isinstance(prompts, str):
prompts = [prompts]

if apply_chat_template:
apply_chat_template_util(self.model, prompts)

if isinstance(stop_at, str):
stop_at = [stop_at]

Expand Down Expand Up @@ -423,7 +446,9 @@ class SequenceGeneratorAdapter:

"""

def __init__(self, model, logits_processor, sampler):
def __init__(
self, model, logits_processor, sampler, apply_chat_template: bool = True
):
self.model = model
self.logits_processor = logits_processor

Expand All @@ -444,6 +469,8 @@ def __init__(self, model, logits_processor, sampler):
"beam_search", sampler.samples, None, None, 1.0
)

self.apply_chat_template = apply_chat_template

def prepare_generation_parameters(
self,
max_tokens: Optional[int],
Expand Down Expand Up @@ -485,9 +512,15 @@ def __call__(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
apply_chat_template: Optional[bool] = None,
**model_specific_params,
):
"""Generate text from a prompt of list of prompts."""
if apply_chat_template is None:
apply_chat_template = self.apply_chat_template

if apply_chat_template:
apply_chat_template_util(self.model, prompts)

def format(sequences):
"""Apply formatting to every string in a completion."""
Expand Down Expand Up @@ -516,9 +549,14 @@ def stream(
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
seed: Optional[int] = None,
apply_chat_template: Optional[bool] = None,
**model_specific_params,
):
"""Return a text generator from a prompt or a list of prompts."""
if apply_chat_template is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this pythonic, we should have one obvious way of applying a chat template. IMO the argument should only be accepted in the constructor.

apply_chat_template = self.apply_chat_template
if apply_chat_template:
apply_chat_template_util(self.model, prompts)
generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
Expand All @@ -529,3 +567,22 @@ def stream(
self.sampling_params,
**model_specific_params,
)


def apply_chat_template_util(model, prompts: Union[str, List[str]]) -> List[str]:
from outlines.models.transformers import TransformerTokenizer

if isinstance(prompts, str):
prompts = [prompts]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the signature should be List[str] -> List[str] and raise an error if a list isn't passed. In transformers.py, this function is called after the prompts are normalized to 2D anyways.

if not isinstance(model.tokenizer, TransformerTokenizer):
warnings.warn(
"Chat template is only supported for `Transformer` models for now. The raw prompts will be used instead."
)
return prompts
tokenizer: "TransformerTokenizer" = model.tokenizer
if getattr(tokenizer.tokenizer, "chat_template", None) is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warnings.warn(
"The model does not have chat template support. The raw prompts will be used instead. To turn this warning off, either explicitly set the `apply_chat_template` argument to 'False' or assign a value to `model.tokenizer.tokenizer.chat_template`."
)
return prompts
return [tokenizer.apply_chat_template(prompt) for prompt in prompts]
14 changes: 11 additions & 3 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


@singledispatch
def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator:
def cfg(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
) -> SequenceGenerator:
"""Generate text in the language of a Context-Free Grammar

Arguments
Expand All @@ -29,7 +34,7 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera
"""
fsm = CFGGuide(cfg_str, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
generator = SequenceGenerator(fsm, model, sampler, device, apply_chat_template)

return generator

Expand All @@ -40,6 +45,7 @@ def cfg_unimplemented(
model,
cfg_str: str,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
):
raise NotImplementedError(
f"The CFG Logits processor is not available for {type(model)}."
Expand All @@ -55,7 +61,9 @@ def cfg_llamacpp(
from outlines.integrations.llamacpp import CFGLogitsProcessor

logits_processor = CFGLogitsProcessor(cfg_str, model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)
return SequenceGeneratorAdapter(
model, logits_processor, sampler, apply_chat_template=False
)


@cfg.register(OpenAI)
Expand Down
7 changes: 5 additions & 2 deletions outlines/generate/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@


def fsm(
model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial()
model,
fsm: interegular.fsm.FSM,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
) -> SequenceGenerator:
fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
generator = SequenceGenerator(fsm, model, sampler, device, apply_chat_template)
return generator
23 changes: 18 additions & 5 deletions outlines/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


@singledispatch
def regex(model, regex_str: str, sampler: Sampler = multinomial()):
def regex(
model,
regex_str: str,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
):
"""Generate structured text in the language of a regular expression.

Parameters
Expand All @@ -33,7 +38,7 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
fsm = RegexGuide(regex_str, model.tokenizer)

device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
generator = SequenceGenerator(fsm, model, sampler, device, apply_chat_template)

return generator

Expand All @@ -43,11 +48,14 @@ def regex_mlxlm(
model: MLXLM,
regex_str: str,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
):
from outlines.processors import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer)
return SequenceGeneratorAdapter(model, logits_processor, sampler)
return SequenceGeneratorAdapter(
model, logits_processor, sampler, apply_chat_template
)


@regex.register(LlamaCpp)
Expand All @@ -59,19 +67,24 @@ def regex_llamacpp(
from outlines.integrations.llamacpp import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, llm=model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)
return SequenceGeneratorAdapter(
model, logits_processor, sampler, apply_chat_template=False
)


@regex.register(VLLM)
def regex_vllm(
model: VLLM,
regex_str: str,
sampler: Sampler = multinomial(),
apply_chat_template: bool = True,
):
from outlines.integrations.vllm import RegexLogitsProcessor

logits_processor = RegexLogitsProcessor(regex_str, model.model)
return SequenceGeneratorAdapter(model, logits_processor, sampler)
return SequenceGeneratorAdapter(
model, logits_processor, sampler, apply_chat_template
)


@regex.register(OpenAI)
Expand Down
20 changes: 13 additions & 7 deletions outlines/generate/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@


@singledispatch
def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
def text(
model, sampler: Sampler = multinomial(), apply_chat_template: bool = True
) -> SequenceGenerator:
"""Generate text with a `Transformer` model.

Note
Expand All @@ -31,24 +33,28 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator:
"""
fsm = StopAtEOSGuide(model.tokenizer)
device = model.device
generator = SequenceGenerator(fsm, model, sampler, device)
generator = SequenceGenerator(fsm, model, sampler, device, apply_chat_template)

return generator


@text.register(MLXLM)
def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
def text_mlxlm(
model: MLXLM, sampler: Sampler = multinomial(), apply_chat_template: bool = True
):
return SequenceGeneratorAdapter(model, None, sampler, apply_chat_template)


@text.register(VLLM)
def text_vllm(model: VLLM, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
def text_vllm(
model: VLLM, sampler: Sampler = multinomial(), apply_chat_template: bool = True
):
return SequenceGeneratorAdapter(model, None, sampler, apply_chat_template)


@text.register(LlamaCpp)
def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()):
return SequenceGeneratorAdapter(model, None, sampler)
return SequenceGeneratorAdapter(model, None, sampler, apply_chat_template=False)


@text.register(OpenAI)
Expand Down
Loading
Loading