Rename and document new methods
rlouf committed Dec 23, 2023
1 parent aa7b18f commit e7e9e08
Showing 1 changed file with 88 additions and 29 deletions.
117 changes: 88 additions & 29 deletions outlines/generate/api.py
@@ -18,12 +18,13 @@


class SequenceGenerator:
def __init__(self, fsm, model, sampler, device, stop_at=None, max_tokens=None):
def __init__(self, fsm, model, sampler, device, max_tokens=None, stop_at=None):
self.generate_token = token_generator(model, sampler)
self.fsm = fsm
self.tokenizer = model.tokenizer
self.device = device
self.max_tokens = max_tokens

if isinstance(stop_at, str):
stop_at = [stop_at]
self.stop_sequences = stop_at
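The constructor accepts stop_at either as a single string or as a list of strings and normalizes it to a list. A minimal standalone sketch of that normalization, assuming the same semantics:

from typing import List, Optional, Union

def normalize_stop_at(stop_at: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    # Mirror the constructor: wrap a bare string in a list so downstream
    # code can always iterate over the stop sequences.
    if isinstance(stop_at, str):
        return [stop_at]
    return stop_at

assert normalize_stop_at("END") == ["END"]
assert normalize_stop_at(["END", "STOP"]) == ["END", "STOP"]
assert normalize_stop_at(None) is None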
@@ -49,30 +50,70 @@ def get_generated_token_ids(
prompts: List[str],
last_state: GenerationState,
) -> List[torch.Tensor]:
"""Give the tokens generated (so the current sequences without the initial user prompts)"""
# Get the number of tokens in the prompts
"""Get the tokens generated so far.
Parameters
----------
init_state
The initial state of the generation.
prompts
The prompts passed to the generator.
last_state
The current state of the generation
Returns
-------
A tensor that contains the token ids that have been generated so far.
"""
prompt_token_ids = init_state[0]
prompt_lengths = [len(prompt_token_ids[i]) for i in range(len(prompts))]
# Remove the prompts from the generated sequences

token_ids = [
cur_token_ids[length:]
for cur_token_ids, length in zip(last_state.token_ids, prompt_lengths)
]

return token_ids

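Conceptually, get_generated_token_ids slices each full sequence past the length of its prompt. A small sketch with plain Python lists standing in for the torch tensors, and made-up token ids:

# Plain lists stand in for torch tensors; the token ids are invented.
prompt_token_ids = [[101, 2023, 2003], [101, 2178]]
sequence_token_ids = [[101, 2023, 2003, 7, 8], [101, 2178, 9]]

prompt_lengths = [len(ids) for ids in prompt_token_ids]
generated = [
    sequence[length:]
    for sequence, length in zip(sequence_token_ids, prompt_lengths)
]
assert generated == [[7, 8], [9]]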
def is_stop_sequence_reached(
def is_stop_sequence_found(
self, generated_sequences: List[str], stop_sequences: List[str]
) -> bool:
"""True if at least one of the stop sequences is found in each generated sequence"""
"""Determine whether one of the stop sequences has been generated.
Parameters
----------
generated_sequences
The list of sequences generated so far.
stop_sequences
The list that contains the sequence which stop the generation when
found.
Returns
-------
True if at least one of the stop sequences has been found in each generated
sequence.
"""
return all(
[
any([seq in generated for seq in stop_sequences])
for generated in generated_sequences
]
)

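The all/any nesting above means generation only halts once every sequence in the batch contains at least one stop sequence. A standalone copy of the predicate, exercised on toy inputs:

from typing import List

def is_stop_sequence_found(
    generated_sequences: List[str], stop_sequences: List[str]
) -> bool:
    # Every generated sequence must contain at least one of the stop sequences.
    return all(
        any(seq in generated for seq in stop_sequences)
        for generated in generated_sequences
    )

assert is_stop_sequence_found(["a END", "b STOP"], ["END", "STOP"])
assert not is_stop_sequence_found(["a END", "b"], ["END", "STOP"])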
def format_sequence(self, sequence: str, stop_sequences: List[str]) -> str:
"""Format the text sequence generated before returning it to the user"""
def strip_stop_sequences(self, sequence: str, stop_sequences: List[str]) -> str:
"""Remove the stop sequences from the generated sequences.
Parameters
----------
sequence
One of the generated sequences.
stop_sequences
The list that contains the sequence which stop the generation when
found.
"""
if stop_sequences:
match_indexes = [sequence.find(seq) for seq in stop_sequences]
if any([index != -1 for index in match_indexes]):
@@ -83,10 +124,25 @@ def format_sequence(self, sequence: str, stop_sequences: List[str]) -> str:
: match_indexes[min_match_index_pos]
+ len(stop_sequences[min_match_index_pos])
]
return self.structure_sequence(sequence)

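strip_stop_sequences cuts each sequence immediately after the earliest stop sequence it contains. A standalone sketch, assuming the lines elided from the hunk above compute the position of that earliest match:

from typing import List, Optional

def strip_stop_sequences(sequence: str, stop_sequences: Optional[List[str]]) -> str:
    # Truncate the text right after the earliest stop sequence found in it.
    if stop_sequences:
        match_indexes = [sequence.find(seq) for seq in stop_sequences]
        if any(index != -1 for index in match_indexes):
            min_match_index_value = min(i for i in match_indexes if i != -1)
            min_match_index_pos = match_indexes.index(min_match_index_value)
            sequence = sequence[
                : match_indexes[min_match_index_pos]
                + len(stop_sequences[min_match_index_pos])
            ]
    return sequence

assert strip_stop_sequences("hello END world", ["END"]) == "hello END"
assert strip_stop_sequences("untouched", ["END"]) == "untouched"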
def structure_sequence(self, sequence: str) -> str:
"""Modify the structure/type of the sequence, is overriden in some generate functions"""
return sequence

def format_sequence(self, sequence: str) -> str:
"""Translate the generated sequence to another type.
This method is for instance overridden when generating JSON to either
return a dictionnary or a Pydantic model.
Parameters
----------
sequence
A generated sequences.
Returns
-------
The formatted sequence.
"""
return sequence

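format_sequence is an identity function here; generators specialize it by assigning to the attribute rather than subclassing, as json() does further down. A toy sketch of that pattern:

import json as pyjson

class Generator:
    def format_sequence(self, sequence: str):
        # Identity by default; replaced on a per-generator basis.
        return sequence

generator = Generator()
# As in json() below: parse raw model output into a Python object.
generator.format_sequence = lambda x: pyjson.loads(x)
assert generator.format_sequence('{"answer": 42}') == {"answer": 42}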
def __call__(
@@ -135,9 +191,10 @@ def __call__(

if isinstance(stop_at, str):
stop_at = [stop_at]
stop_sequences = stop_at or self.stop_sequences

stop_sequences = stop_at or self.stop_sequences
max_tokens = max_tokens or self.max_tokens
num_sequences = len(prompts)

if rng is None:
rng = torch.Generator(device=self.device)
@@ -146,7 +203,6 @@
init_state = init_generator_state(
self.tokenizer, self.device, prompts, kv_cache
)
num_sequences = len(prompts)
init_fsm_states = [FSMState(0) for _ in range(num_sequences)]

states = sequence_generator(
@@ -157,25 +213,28 @@
try:
last_state = next(states)
if max_tokens or stop_sequences:
token_ids = self.get_generated_token_ids(
generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state
)
if max_tokens and len(token_ids[0]) >= max_tokens:
if max_tokens and len(generated_token_ids[0]) >= max_tokens:
break
if stop_sequences and self.is_stop_sequence_reached(
self.tokenizer.decode(token_ids), stop_sequences
if stop_sequences and self.is_stop_sequence_found(
self.tokenizer.decode(generated_token_ids), stop_sequences
):
break
except StopIteration:
break

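The loop above advances the sequence generator one state at a time and breaks as soon as the token budget is exhausted, a stop sequence appears, or the generator itself is exhausted. A minimal sketch of that shape, with a plain iterator standing in for the FSM-driven generator:

def run(states, max_tokens):
    # Advance step by step; stop on the budget or when the iterator ends.
    steps = 0
    last_state = None
    while True:
        try:
            last_state = next(states)
        except StopIteration:
            break
        steps += 1
        if max_tokens and steps >= max_tokens:
            break
    return last_state

assert run(iter(range(100)), max_tokens=3) == 2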
token_ids = self.get_generated_token_ids(init_state, prompts, last_state)
generated = self.tokenizer.decode(token_ids)

generated_token_ids = self.get_generated_token_ids(
init_state, prompts, last_state
)
generated = self.tokenizer.decode(generated_token_ids)
stripped = [
self.strip_stop_sequences(sequence, stop_sequences)
for sequence in generated
]
try:
formatted = [
self.format_sequence(sequence, stop_sequences) for sequence in generated
]
formatted = [self.format_sequence(sequence) for sequence in stripped]
except pyjson.decoder.JSONDecodeError:
raise TypeError(
"Could not format the output of the model into a dictionary or a Pydantic model."
@@ -277,7 +336,7 @@ def token_generator() -> Iterator[Union[List[str], str]]:
if stop_sequences:
is_stop_at_reached = [
stop
or self.is_stop_sequence_reached(
or self.is_stop_sequence_found(
[generated_sequence], stop_sequences
)
for generated_sequence, stop in zip(
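In the streaming path each sequence keeps a sticky flag: once a stop sequence has appeared in its text, the flag stays set on every later step. A standalone sketch of that bookkeeping:

from typing import List

def update_stop_flags(
    is_stop_at_reached: List[bool],
    generated_sequences: List[str],
    stop_sequences: List[str],
) -> List[bool]:
    # A sequence that has stopped stays stopped; others stop as soon as
    # any stop sequence shows up in their generated text.
    return [
        stop or any(seq in generated for seq in stop_sequences)
        for generated, stop in zip(generated_sequences, is_stop_at_reached)
    ]

assert update_stop_flags([False, True], ["abc END", "still going"], ["END"]) == [True, True]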
@@ -301,7 +360,7 @@ def text(

device = model.device
generator = SequenceGenerator(
fsm, model, sampler, device, stop_at=stop_at, max_tokens=max_tokens
fsm, model, sampler, device, max_tokens=max_tokens, stop_at=stop_at
)

return generator
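A hypothetical call site for text(); the outlines.models.transformers constructor is an assumption about the surrounding library, not something this diff shows:

# Hypothetical usage; the model constructor is assumed.
import outlines

model = outlines.models.transformers("gpt2")
generator = outlines.generate.text(model, max_tokens=64, stop_at=".")
print(generator("The capital of France is"))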
@@ -332,7 +391,7 @@ def cfg(

device = model.device
generator = SequenceGenerator(
fsm, model, sampler, device, stop_at=stop_at, max_tokens=max_tokens
fsm, model, sampler, device, max_tokens=max_tokens, stop_at=stop_at
)

return generator
@@ -368,17 +427,17 @@ def json(
schema = pyjson.dumps(schema_object.model_json_schema())
regex_str = build_regex_from_object(schema)
generator = regex(model, regex_str, max_tokens, sampler)
generator.structure_sequence = lambda x: schema_object.parse_raw(x)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_object(schema)
generator = regex(model, regex_str, max_tokens, sampler)
generator.structure_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = lambda x: pyjson.loads(x)
elif isinstance(schema_object, str):
schema = schema_object
regex_str = build_regex_from_object(schema)
generator = regex(model, regex_str, max_tokens, sampler)
generator.structure_sequence = lambda x: pyjson.loads(x)
generator.format_sequence = lambda x: pyjson.loads(x)
else:
raise ValueError(
f"Cannot parse schema {schema_object}. The schema must be either "

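A hypothetical end-to-end use of json() with a Pydantic schema; the model constructor is an assumption, and the parsed return type follows from the format_sequence override above:

# Hypothetical usage; the model constructor is assumed.
import outlines
from pydantic import BaseModel

class Character(BaseModel):
    name: str
    age: int

model = outlines.models.transformers("gpt2")
generator = outlines.generate.json(model, Character)
character = generator("Describe a character as JSON:")
assert isinstance(character, Character)  # format_sequence parsed the JSON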