
Basic proposition implementation token alignment
RobinPicard committed Jan 11, 2024
1 parent 37f53ca commit a04e8d4
Showing 2 changed files with 57 additions and 4 deletions.
43 changes: 42 additions & 1 deletion outlines/fsm/fsm.py
@@ -1,5 +1,7 @@
from copy import deepcopy
from typing import TYPE_CHECKING, List, NewType, Protocol

import cloudpickle
import interegular
from lark import Lark

@@ -119,9 +121,43 @@ def __init__(self, regex_string: str, tokenizer: "Tokenizer"):
        self.final_states = regex_fsm.finals | {
            -1
        }  # Include the EOS token in final states
        self.tokenizer = tokenizer
        self.vocabulary = tokenizer.vocabulary.values()
        self.end_token_id = tokenizer.eos_token_id

    def align_prompt_tokens(self, prompt: str) -> str:
        """Remove the last token from the prompt and update the states_to_token_maps accordingly"""
        token_ids, _ = self.tokenizer.encode(prompt)
        last_token_id = int(token_ids[0][-1])
        last_token_text = self.tokenizer.decode([last_token_id])[0]
        vocabulary = {
            self.tokenizer.decode([token_id])[0]: token_id
            for token_id in range(len(self.vocabulary))
        }
        starting_state_tokens = {
            self.tokenizer.decode([token_id])[0]: self.states_to_token_maps[0][token_id]
            for token_id in self.states_to_token_maps[0]
        }
        # Select the tokens that start with the text removed from the prompt and whose
        # remaining text matches one of the allowed tokens of the starting state.
        possible_tokens = {
            vocabulary[token_text]: starting_state_tokens[token_text[len(last_token_text):]]
            for token_text in vocabulary
            if (
                token_text.startswith(last_token_text)
                and starting_state_tokens.get(token_text[len(last_token_text):])
            )
        }
        # Update the states_to_token_maps as follows: the transitions of the starting state
        # are moved to a new state; the starting state now contains the possible_tokens found
        # above plus the removed last_token, which leads to the new state.
        additional_state_id = max(list(self.states_to_token_maps.keys()) + list(self.final_states)) + 1
        self.states_to_token_maps[additional_state_id] = self.states_to_token_maps[0]
        self.states_to_token_maps[0] = {**possible_tokens, last_token_id: additional_state_id}

        return prompt[:-len(last_token_text)]
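
As an illustration, here is a minimal sketch of the rewrite this method performs, using an invented three-token vocabulary and a hand-written states_to_token_maps; none of these values come from a real tokenizer, they only show the transformation.

# Hypothetical vocabulary (token text -> token id): "(" -> 1, "(a" -> 2, "a" -> 3
vocabulary = {"(": 1, "(a": 2, "a": 3}

# Original map: from the start state 0, only "a" (id 3) is allowed, leading to state 1.
states_to_token_maps = {0: {3: 1}}

# Suppose the prompt ends with the text "(" (token id 1). After alignment:
#  - the old start-state transitions move to a fresh state (here 2),
#  - state 0 allows re-emitting "(" (id 1), which leads to that fresh state,
#  - state 0 also allows "(a" (id 2), a token that crosses the prompt boundary:
#    it starts with the removed "(" and continues with "a", so it jumps straight
#    to wherever "a" would have led (state 1).
aligned_states_to_token_maps = {
    0: {2: 1, 1: 2},  # "(a" -> state 1, "(" -> new state 2
    2: {3: 1},        # copy of the old start state
}

The method then returns the prompt with the trailing "(" removed, so generation re-produces that text itself and may merge it with the continuation into a single token such as "(a".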


    def allowed_token_ids(self, state: FSMState) -> List[int]:
        """Generate a list of allowed tokens for the next step.
@@ -186,7 +222,12 @@ def is_final_state(self, state: FSMState) -> bool:

    def copy(self) -> "RegexFSM":
        """Create a copy of the FSM."""
        return self
        # Temporary solution to the problem of unpickleable dict_values
        self.vocabulary = cloudpickle.dumps(self.vocabulary)
        copy = deepcopy(self)
        self.vocabulary = cloudpickle.loads(self.vocabulary)
        copy.vocabulary = cloudpickle.loads(copy.vocabulary)
        return copy
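
For context, dict_values views are not picklable, so a plain deepcopy fails once self.vocabulary holds tokenizer.vocabulary.values(). The sketch below reproduces the failure and the serialize/restore pattern used above, assuming a cloudpickle version that can serialize dict view objects (as this commit relies on); the Holder class is an invented stand-in, not part of outlines.

import copy

import cloudpickle


class Holder:
    """Toy stand-in for an object holding a dict_values view."""

    def __init__(self) -> None:
        self.vocabulary = {"a": 1, "b": 2}.values()  # dict_values, not a list


holder = Holder()
try:
    copy.deepcopy(holder)
except TypeError as error:
    print(f"deepcopy fails: {error}")  # cannot pickle 'dict_values' object

# Workaround mirroring RegexFSM.copy(): serialize the view with cloudpickle,
# deepcopy the object, then restore the view on both the original and the copy.
holder.vocabulary = cloudpickle.dumps(holder.vocabulary)
clone = copy.deepcopy(holder)
holder.vocabulary = cloudpickle.loads(holder.vocabulary)
clone.vocabulary = cloudpickle.loads(clone.vocabulary)
print(list(clone.vocabulary))  # [1, 2]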


class CFGFSM(FSM):
18 changes: 15 additions & 3 deletions outlines/generate/api.py
@@ -76,6 +76,16 @@ def get_generated_token_ids(

        return token_ids

    def get_generated_sequences(
        self, generated_token_ids: List[torch.Tensor], initial_prompts: List[str], prompts: List[str]
    ) -> List[str]:
        """Return the text sequences generated, based on the generated tokens and the initial prompts"""
        generated_tokens_text = self.tokenizer.decode(generated_token_ids)
        return [
            generated_tokens_text[i][len(initial_prompts[i]) - len(prompts[i]):]
            for i in range(len(generated_tokens_text))
        ]
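
A small sketch of the offset arithmetic, with values invented for illustration: because alignment shortened the prompt, the decoded generation starts with the removed text (possibly merged into a larger token), and len(initial_prompts[i]) - len(prompts[i]) is exactly the number of characters to drop.

initial_prompt = "Question: (1+1"    # prompt as the user passed it
aligned_prompt = "Question: (1+"     # after align_prompt_tokens stripped the "1"
decoded_generation = "1) equals 2"   # the model re-emits the stripped "1", possibly merged into "1)"

offset = len(initial_prompt) - len(aligned_prompt)  # 1
print(decoded_generation[offset:])  # ") equals 2" -- only the newly generated text remains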

    def is_stop_sequence_found(
        self, generated_sequences: List[str], stop_sequences: List[str]
    ) -> bool:
@@ -186,6 +196,7 @@ def __call__(

        if isinstance(prompts, str):
            prompts = [prompts]
        initial_prompts = copy.deepcopy(prompts)

        if isinstance(stop_at, str):
            stop_at = [stop_at]
@@ -194,6 +205,7 @@
        max_tokens = max_tokens or self.max_tokens
        num_sequences = len(prompts)
        fsms = [self.fsm.copy() for _ in prompts]
        prompts = [fsm.align_prompt_tokens(prompt) for fsm, prompt in zip(fsms, prompts)]

        if rng is None:
            rng = torch.Generator(device=self.device)
@@ -213,7 +225,7 @@ def __call__(
                last_state = next(states)
                if max_tokens or stop_sequences:
                    generated_token_ids = self.get_generated_token_ids(
                        init_state, prompts, last_state
                        init_state, initial_prompts, last_state
                    )
                    if max_tokens and len(generated_token_ids[0]) >= max_tokens:
                        break
@@ -225,9 +237,9 @@
                break

        generated_token_ids = self.get_generated_token_ids(
            init_state, prompts, last_state
            init_state, initial_prompts, last_state
        )
        generated = self.tokenizer.decode(generated_token_ids)
        generated = self.get_generated_sequences(generated_token_ids, initial_prompts, prompts)
        stripped = [
            self.strip_stop_sequences(sequence, stop_sequences)
            for sequence in generated
