Align prompt and generation
RobinPicard authored and rlouf committed Mar 11, 2024
1 parent 11143df commit 29853ec
Showing 4 changed files with 582 additions and 23 deletions.
233 changes: 224 additions & 9 deletions outlines/fsm/guide.py
@@ -1,7 +1,10 @@
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Protocol, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Protocol, Tuple, Union

import interegular
import torch
from lark import Lark

from outlines import grammars
@@ -62,11 +65,16 @@ def get_next_state(self, state: int, token_id: int) -> int:
def is_final_state(self, state: int) -> bool:
...

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
...


class StopAtEOSGuide(Guide):
"""Guide to generate tokens until the EOS token has been generated."""

final_state = 1
final_state = -1
start_state = 0

def __init__(self, tokenizer: "Tokenizer"):
@@ -77,24 +85,52 @@ def __init__(self, tokenizer: "Tokenizer"):
"""
self.eos_token_id = tokenizer.eos_token_id
self.vocabulary = tokenizer.vocabulary.values()
self.vocabulary = tokenizer.vocabulary
self.tokenizer = tokenizer
self.states_to_token_maps = self.create_states_to_tokens_map()

def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]:
"""Create the states_to_tokens_map. All tokens from the starting state lead
to itself, except for the eos_token that leads to the final state."""
return {
self.start_state: {
token_id: self.start_state
if token_id != self.eos_token_id
else self.final_state
for token_id in self.vocabulary.values()
}
}
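
# Illustration with assumed toy values: for a vocabulary {"a": 0, "<eos>": 1}
# and eos_token_id == 1, this returns {0: {0: 0, 1: -1}}: every token loops back
# to the start state, while <eos> reaches the final state -1.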

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
(
token_ids,
attention_masks,
self.states_to_token_maps,
) = align_tokens_states_to_token_maps(
token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
)
return token_ids, attention_masks

def get_next_instruction(self, state: int) -> Instruction:
if self.is_final_state(state):
return Write([self.eos_token_id])
return Generate(list(self.vocabulary))

return Generate(list(self.states_to_token_maps[state].keys()))

def get_next_state(self, state: int, token_id: int) -> int:
if token_id == self.eos_token_id or state == self.final_state:
if self.is_final_state(state):
return self.final_state

return self.start_state
return self.states_to_token_maps[state][token_id]

def is_final_state(self, state: int):
return state == self.final_state

def copy(self):
return self
return deepcopy(self)


class RegexGuide(Guide):
@@ -136,10 +172,23 @@ def create_states_mapping(
) = create_states_mapping(
regex_string, tuple(sorted(tokenizer.vocabulary.items()))
)
self.vocabulary = list(tokenizer.vocabulary.values())
self.vocabulary = tokenizer.vocabulary
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update the states_to_token_maps and return the aligned prompt tokens and attention masks"""
(
token_ids,
attention_masks,
self.states_to_token_maps,
) = align_tokens_states_to_token_maps(
token_ids, attention_masks, self.vocabulary, self.states_to_token_maps
)
return token_ids, attention_masks

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -244,7 +293,7 @@ def is_final_state(self, state: int) -> bool:
return state in self.final_states

def copy(self):
return self
return deepcopy(self)


class CFGGuide(Guide):
@@ -281,6 +330,12 @@ def __init__(self, cfg_string: str, tokenizer):
self.start_state = 0
self.final_state = -1

def align_prompt_tokens(
self, token_ids: torch.Tensor, attention_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Not applicable to this type of FSM"""
return token_ids, attention_masks

def get_next_instruction(self, state: int) -> Instruction:
"""Generate an instruction for the next step.
@@ -416,3 +471,163 @@ def is_final_state(self, state: int) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the FSM."""
return CFGGuide(self.cfg_string, self.tokenizer)


def align_tokens_states_to_token_maps(
token_ids: torch.Tensor,
attention_masks: torch.Tensor,
vocabulary: Dict[str, int],
states_to_token_maps: Dict[int, Dict[int, int]],
) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]:
"""Apply token alignment to the provided prompt tokens and attention masks given the
states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated
states_to_token_maps"""
prompt_token_ids = token_ids.tolist()
crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary)
valid_crossing_tokens = get_crossing_tokens_target_states(
states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary
)
if not valid_crossing_tokens:
return token_ids, attention_masks, states_to_token_maps
(
states_to_token_maps,
number_cropped_tokens,
) = add_crossing_tokens_states_to_tokens_map(
states_to_token_maps, prompt_token_ids, valid_crossing_tokens
)
return (
token_ids[:-number_cropped_tokens],
attention_masks[:-number_cropped_tokens],
states_to_token_maps,
)


def find_crossing_tokens(
token_ids: List[int], vocabulary: Dict[str, int]
) -> Dict[int, List[int]]:
"""Find the tokens that could replace one or more tokens at the end of token_ids
while conserving the same intial text (and extending it by at least one character).
Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
len_token_ids = len(token_ids)
max_length_token_text = max(len(item) for item in vocabulary.keys())
characters_considered = ""
crossing_tokens_map = {}

for index, token_id in enumerate(reversed(token_ids)):
characters_considered = reversed_vocabulary[token_id] + characters_considered
if len(characters_considered) >= max_length_token_text:
break
crossing_token_ids = [
token_id
for text, token_id in vocabulary.items()
if text.startswith(characters_considered)
and len(text) > len(characters_considered)
]
if crossing_token_ids:
crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids

return crossing_tokens_map
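
# Worked illustration (assumed toy vocabulary; ids are arbitrary): the prompt
# "hello" is tokenized as ["hel", "lo"], and "low" can cross the boundary of the
# last token because it starts with "lo" and extends the text by one character.
# _toy_vocab = {"hel": 0, "lo": 1, "low": 2, "w": 3, "<eos>": 4}
# find_crossing_tokens([0, 1], _toy_vocab)  # -> {1: [2]}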


def get_crossing_tokens_target_states(
states_to_tokens_map: Dict[int, Dict[int, int]],
crossing_tokens: Dict[int, List[int]],
prompt_token_ids: List[int],
vocabulary: Dict[str, int],
) -> Dict[int, Dict[int, int]]:
"""For each crossing token associated to an index, check that the characters after the boundary
match the states_to_tokens_map and find the state it would lead to. Return a dict with, for each
provided indexes, the associated valid tokens with the state they would lead to.
"""
reversed_vocabulary = {value: key for key, value in vocabulary.items()}
prompt_token_texts = [
reversed_vocabulary[token_id] for token_id in prompt_token_ids
]

valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict)
for pos, tokens in crossing_tokens.items():
for token in tokens:
is_valid = True
characters = reversed_vocabulary[token]
characters_before_border = "".join(prompt_token_texts[pos:])
characters_after_border = characters[len(characters_before_border) :]
state = 0
for char in characters_after_border:
char_token = vocabulary.get(char)
try:
state = states_to_tokens_map[state][char_token] # type: ignore
except KeyError:
is_valid = False
break
if is_valid:
valid_crossing_tokens[pos][token] = state

return valid_crossing_tokens
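
# Continuing the assumed toy example above: with a single-state map in which every
# token loops back to state 0 except <eos> (id 4), the crossing token "low" (id 2)
# at position 1 is valid because its leftover character "w" is itself a token (id 3)
# with a legal transition from state 0, so the result is {1: {2: 0}}.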


def add_crossing_tokens_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
prompt_token_ids: List[int],
crossing_tokens_map: Dict[int, Dict[int, int]],
) -> Tuple[Dict[int, Dict[int, int]], int]:
"""Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies
the starting state of the fsm as we would include some characters at the end of the prompt in
the states_to_tokens_map.
Attention! the starting state of the states_to_tokens_map provided must be 0.
Return the updated states_to_tokens_map and the number of cropped tokens/additional states
"""
if not crossing_tokens_map:
return states_to_tokens_map, 0
first_crossing_token_pos = min(
[key for key, value in crossing_tokens_map.items() if value]
)
number_additional_states = len(prompt_token_ids) - first_crossing_token_pos
highest_state = max(
max(states_to_tokens_map.keys()),
max(max(items.values()) for items in states_to_tokens_map.values()),
)

for i in range(number_additional_states):
# add the token that was originally part of the prompt
if i == number_additional_states - 1:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: 0
}
else:
states_to_tokens_map[highest_state + 1 + i] = {
prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i
}
# add the crossing tokens
crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i)
if crossing_tokens:
for token, target_state in crossing_tokens.items():
states_to_tokens_map[highest_state + 1 + i][token] = target_state

# set the id of our new initial state to 0
states_to_tokens_map = swap_state_ids_states_to_tokens_map(
states_to_tokens_map, highest_state + 1, 0
)
return states_to_tokens_map, number_additional_states


def swap_state_ids_states_to_tokens_map(
states_to_tokens_map: Dict[int, Dict[int, int]],
first_state_id: int,
second_state_id: int,
) -> Dict[int, Dict[int, int]]:
"""Swap the id of two states of the states_to_tokens_map while conserving all transitions"""
first_state_transitions = states_to_tokens_map.pop(first_state_id)
second_state_transitions = states_to_tokens_map.pop(second_state_id)
states_to_tokens_map[first_state_id] = second_state_transitions
states_to_tokens_map[second_state_id] = first_state_transitions

for transitions in states_to_tokens_map.values():
for token, target_state_id in list(transitions.items()):
if target_state_id == first_state_id:
transitions[token] = second_state_id
elif target_state_id == second_state_id:
transitions[token] = first_state_id

return states_to_tokens_map
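
Taken together, a caller would invoke align_prompt_tokens on the prompt before the first decoding step. A minimal sketch, assuming the toy vocabulary used above (the ToyTokenizer object and all ids below are illustrative only, not part of this diff):

import torch

class ToyTokenizer:
    # Only the two attributes the guides rely on: vocabulary and eos_token_id.
    vocabulary = {"hel": 0, "lo": 1, "low": 2, "w": 3, "<eos>": 4}
    eos_token_id = 4

guide = StopAtEOSGuide(ToyTokenizer())

# Prompt "hello" encoded as ["hel", "lo"].
token_ids = torch.tensor([0, 1])
attention_masks = torch.ones_like(token_ids)

# Alignment crops the trailing "lo" and folds it into the guide's first transition.
token_ids, attention_masks = guide.align_prompt_tokens(token_ids, attention_masks)
# token_ids is now tensor([0]), and the first instruction only allows "lo" or "low":
# guide.get_next_instruction(guide.start_state)  # -> Generate([1, 2])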