Refactor the FSM module #734

Merged 2 commits on Mar 7, 2024
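In short, this PR removes the concrete `first_state` / `final_state` attributes from the `FSM` protocol and moves state handling into the implementations: every FSM now exposes its own `start_state` and `is_final_state`, `StopAtEosFSM` keeps a single `final_state`, and `RegexFSM` tracks the full set of accepting states of the underlying automaton in `final_states` (with `-1` as the post-EOS sentinel). A sketch of the resulting protocol, reconstructed from the diff below; the `next_state` signature is collapsed in the diff and inferred from the implementations:

```python
from typing import List, NewType, Protocol

# Integer state type used throughout the module (definition site assumed).
FSMState = NewType("FSMState", int)


class FSM(Protocol):
    def is_final_state(self, state: FSMState) -> bool:
        ...

    def allowed_token_ids(self, state: FSMState) -> List[int]:
        ...

    def next_state(self, state: FSMState, token_id: int) -> FSMState:
        ...

    def copy(self) -> "FSM":
        ...
```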
outlines/fsm/fsm.py (56 changes: 40 additions & 16 deletions)
@@ -15,12 +15,8 @@


 class FSM(Protocol):
-    first_state: FSMState = FSMState(0)
-    final_state: FSMState = FSMState(-1)
-
     def is_final_state(self, state: FSMState) -> bool:
-        """Determine whether the current state of the FSM is a final state."""
-        return state == self.final_state
+        ...

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         ...
@@ -32,12 +28,14 @@ def copy(self) -> "FSM":
         ...


-class StopAtEosFSM(FSM):
+class StopAtEosFSM:
     """FSM to generate text until EOS has been generated."""

     def __init__(self, tokenizer: "Tokenizer"):
         self.eos_token_id = tokenizer.eos_token_id
         self.vocabulary = tokenizer.vocabulary.values()
+        self.start_state: FSMState = FSMState(0)
+        self.final_state: FSMState = FSMState(1)

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -81,21 +79,25 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         if token_id == self.eos_token_id or state == self.final_state:
             return self.final_state

-        return self.first_state
+        return self.start_state

+    def is_final_state(self, state: FSMState) -> bool:
+        """Determine whether the current state of the FSM is a final state."""
+        return state == self.final_state
+
     def copy(self) -> "StopAtEosFSM":
         """Create a copy of the FSM."""
         return self


-class RegexFSM(FSM):
+class RegexFSM:
     """FSM to generate text that is in the language of a regular expression."""

     def __init__(self, regex_string: str, tokenizer):
         @cache()
         def create_states_mapping(
             regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int], ...]
-        ) -> Tuple[dict, set]:
+        ) -> Tuple[dict, set, set]:
             """Create the variables related to the mapping between states and tokens
             The parameters of the function are used for caching purpose
             """
@@ -116,13 +118,19 @@ def create_states_mapping(
                     "The vocabulary does not allow us to build a sequence that matches the input regex"
                 )

-            return states_to_token_maps, empty_token_ids
+            return states_to_token_maps, empty_token_ids, regex_fsm.finals

-        self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
+        (
+            self.states_to_token_maps,
+            self.empty_token_ids,
+            fsm_finals,
+        ) = create_states_mapping(
             regex_string, tuple(sorted(tokenizer.vocabulary.items()))
         )
         self.vocabulary = list(tokenizer.vocabulary.values())
         self.eos_token_id = tokenizer.eos_token_id
+        self.start_state = FSMState(0)
+        self.final_states = fsm_finals | {-1}

     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.
@@ -172,13 +180,17 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         The new state of the FSM.

         """
-        if token_id == self.eos_token_id or state == self.final_state:
-            return self.final_state
+        if token_id == self.eos_token_id:
+            return FSMState(-1)
+        elif (
+            state in self.final_states
+        ):  # Necessary because we keep generating EOS tokens when finished
+            return state

         last_token_to_end_state = self.states_to_token_maps[state]
         next_state = last_token_to_end_state.get(token_id)
         if next_state is None:
-            return self.final_state
+            return FSMState(-1)

         return FSMState(next_state)

@@ -222,6 +234,9 @@ def create_states_mapping_from_interegular_fsm(
         from_interegular_instance.eos_token_id = tokenizer.eos_token_id
         return from_interegular_instance

+    def is_final_state(self, state: FSMState) -> bool:
+        return state in self.final_states
+
     def copy(self) -> "RegexFSM":
         """Create a copy of the FSM."""
         return self
@@ -258,6 +273,9 @@ def __init__(self, cfg_string: str, tokenizer):
         self.proposal_last: List[int] = []
         self.regex_fsm_last: RegexFSM

+        self.start_state = FSMState(0)
+        self.final_state = FSMState(-1)
+
     def allowed_token_ids(self, state: FSMState) -> List[int]:
         """Generate a list of allowed tokens for the next step.

@@ -328,7 +346,7 @@ def allowed_token_ids(self, state: FSMState) -> List[int]:
             self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
             self.reset_state = True

-        proposal += self.regex_fsm.allowed_token_ids(self.first_state)
+        proposal += self.regex_fsm.allowed_token_ids(self.start_state)
         if self.allow_eos:
             self.allow_eos = False
         else:
@@ -354,6 +372,9 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:
         -------
         The new state of the FSM.
         """
+
+        # We need to return the final state when in the final state because we
+        # then generate EOS tokens instead of stopping the generation.
         if token_id == self.tokenizer.eos_token_id or state == self.final_state:
             return self.final_state

@@ -366,10 +387,13 @@ def next_state(self, state: FSMState, token_id: int) -> FSMState:

         if self.reset_state:
             self.reset_state = False
-            state = self.first_state
+            state = self.start_state

         return self.regex_fsm.next_state(state, token_id)

+    def is_final_state(self, state: FSMState) -> bool:
+        return state == self.final_state
+
     def copy(self) -> "CFGFSM":
         """Create a copy of the FSM."""
         return CFGFSM(self.cfg_string, self.tokenizer)
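For orientation, here is a minimal usage sketch (not part of the diff) of how a caller drives the refactored `RegexFSM`: generation starts from `start_state` and stops once `is_final_state` returns `True`, with `FSMState(-1)` reserved for the state reached after EOS. The mock tokenizer mirrors the one used in the tests below and is illustrative only.

```python
from outlines.fsm.fsm import RegexFSM  # import path taken from this diff


class MockTokenizer:
    vocabulary = {"1": 1, "eos": 2}
    special_tokens = {"eos"}
    eos_token_id = 2

    def convert_token_to_string(self, token):
        return token


fsm = RegexFSM(r"[1-9]", MockTokenizer())
state = fsm.start_state  # FSMState(0)
while not fsm.is_final_state(state):
    allowed = fsm.allowed_token_ids(state)  # token ids the model may emit next
    state = fsm.next_state(state, allowed[0])  # stand-in for the model's choice
```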
tests/fsm/test_fsm.py (55 changes: 41 additions & 14 deletions)
@@ -10,11 +10,11 @@ class MockTokenizer:

     fsm = StopAtEosFSM(MockTokenizer())

-    assert fsm.allowed_token_ids(fsm.first_state) == [1, 2]
+    assert fsm.allowed_token_ids(fsm.start_state) == [1, 2]
     assert fsm.allowed_token_ids(fsm.final_state) == [2]
-    assert fsm.next_state(fsm.first_state, 2) == fsm.final_state
-    assert fsm.next_state(fsm.first_state, 1) == fsm.first_state
-    assert fsm.is_final_state(fsm.first_state) is False
+    assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
+    assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
+    assert fsm.is_final_state(fsm.start_state) is False
     assert fsm.is_final_state(fsm.final_state) is True


@@ -49,10 +49,37 @@ def convert_token_to_string(self, token):
     assert fsm.states_to_token_maps == {0: {1: 1}}
     assert fsm.allowed_token_ids(state=0) == [1]
     assert fsm.next_state(state=0, token_id=1) == 1
-    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == fsm.final_state
+    assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1

     assert fsm.is_final_state(0) is False
-    assert fsm.is_final_state(fsm.final_state) is True
+
+    for state in fsm.final_states:
+        assert fsm.is_final_state(state) is True
+
+
+def test_regex_final_state():
+    """Make sure that the FSM stays in the final state as we keep generating"""
+
+    class MockTokenizer:
+        vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
+        special_tokens = {"eos"}
+        eos_token_id = 104
+
+        def convert_token_to_string(self, token):
+            return token
+
+    regex_str = r"`\n(\.\n)?`\n"
+    tokenizer = MockTokenizer()
+    fsm = RegexFSM(regex_str, tokenizer)
+
+    state = fsm.next_state(state=4, token_id=103)
+    assert state == 5
+    assert fsm.is_final_state(state)
+
+    state = fsm.next_state(state=5, token_id=103)
+    assert state == 5
+
+    assert fsm.is_final_state(-1)


 def test_cfg():
@@ -79,8 +106,8 @@ def decode(self, token_ids):
     tokenizer = MockTokenizer()
     fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 3, 5}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3, 5}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "{"
     assert not fsm.is_final_state(state)

@@ -130,8 +157,8 @@ def decode(self, token_ids):
     tokenizer = MockTokenizer()
     fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "("
     assert not fsm.is_final_state(state)

@@ -236,9 +263,9 @@ def decode(self, token_ids):
     tokenizer = MockTokenizer()
     fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 2}
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 2}
     assert fsm.reset_state  # starting new regex
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "a"
     assert not fsm.is_final_state(state)

@@ -279,8 +306,8 @@ def decode(self, token_ids):
     tokenizer = MockTokenizer()
     fsm = CFGFSM(cfg_str, tokenizer)

-    assert set(fsm.allowed_token_ids(state=fsm.first_state)) == {1, 3}
-    state = fsm.next_state(state=fsm.first_state, token_id=1)
+    assert set(fsm.allowed_token_ids(state=fsm.start_state)) == {1, 3}
+    state = fsm.next_state(state=fsm.start_state, token_id=1)
     assert fsm.generation == "("
     assert not fsm.is_final_state(state)
tests/generate/test_integration_llamacpp.py (17 changes: 15 additions & 2 deletions)
@@ -328,15 +328,28 @@ class Spam(BaseModel):

 def test_llamacpp_json_function(model):
     model.model.reset()
-    prompt = "<|im_start|>user\nOutput arguments for the function<|im_end|>\n<|im_start|>assistant\n"
+    prompt = "<|im_start|>user\nOutput arguments for the function, array with 2 elements<|im_end|>\n<|im_start|>assistant\n"

     def function(foo: int, bar: List[int]):
         return foo + sum(bar)

     rng = torch.Generator(device="cpu")
-    rng.manual_seed(0)
+    rng.manual_seed(10)
     sequence = generate.json(model, function)(
         prompt, max_tokens=100, temperature=0.0, rng=rng
     )
     assert isinstance(sequence, dict)
     assert isinstance(function(**sequence), int)
+
+
+def test_llamacpp_successive_choices(model):
+    model.model.reset()
+
+    choose = generate.regex(model, r"(one|two|three)")
+    assert choose("pick a number") in ["one", "two", "three"]
+
+    cities = ["New York", "Paris", "San Francisco"]
+    city = generate.choice(model, cities)
+    assert city("pick a city") in cities
+
+    assert choose("a number") in ["one", "two", "three"]