Fix null byte \x00 issue in byte level fsm resulting in KeyError in BetterFSM::FSMInfo #930

Merged · 8 commits · Jun 5, 2024
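The core of the change: bytes at or above 0x80 are now encoded as a "\x00"-prefixed two-hex-digit symbol, decoded tokens are stored as flat strings rather than tuples of symbols, and FSM walking consumes precomputed transition-key sequences. A minimal sketch of the new encoding, mirroring the `byte_symbol` change in the diff below (plain Python, illustrative only):

```python
def byte_symbol(byte: int) -> str:
    # Non-ASCII bytes become "\x00" plus two hex digits; ASCII bytes stay literal.
    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)

# "naïve" -> b"na\xc3\xafve": the two UTF-8 bytes of "ï" turn into
# self-delimiting hex symbols, everything else passes through unchanged.
token_str = "".join(byte_symbol(b) for b in "naïve".encode("utf-8"))
assert token_str == "na\x00C3\x00AFve"
```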
9 changes: 8 additions & 1 deletion outlines/fsm/parsing.py
@@ -38,6 +38,7 @@
from outlines.fsm.regex import (
fsm_union,
get_sub_fsms_from_seq,
get_token_transitions,
make_deterministic_fsm,
walk_fsm,
)
@@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)

text_part = text[start_pos:]

text_transitions = get_token_transitions(
self.fsm.fsm_info.alphabet_symbol_mapping,
self.fsm.fsm_info.alphabet_anything_value,
text_part,
)

state_seq = walk_fsm(
self.fsm,
text_part,
text_transitions,
start_state,
full_match=self.match_whole,
)
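With this change, `match` no longer hands raw characters to `walk_fsm`: the text is first mapped to transition keys through the FSM's own alphabet, so the walker only deals with integer keys. The same call pattern as the hunk above, restated with explanatory comments (a sketch, not additional code):

```python
text_transitions = get_token_transitions(
    self.fsm.fsm_info.alphabet_symbol_mapping,  # symbol -> transition key mapping
    self.fsm.fsm_info.alphabet_anything_value,  # fallback key for unmapped symbols
    text_part,                                  # the text slice being matched
)
state_seq = walk_fsm(
    self.fsm,
    text_transitions,   # integer transition keys, not raw characters
    start_state,
    full_match=self.match_whole,
)
```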
114 changes: 78 additions & 36 deletions outlines/fsm/regex.py
@@ -87,14 +87,11 @@ def fsm_info(self):
((k, z) for k, v in self.trans_key_to_states.items() for z in v),
dtype=np.dtype("int64, int64"),
)
alphabet_symbol_mapping_items = np.fromiter(
(
it
for it in self.alphabet._symbol_mapping.items()
if it[0] != anything_else
),
dtype=np.dtype("U2, int64"),
)
alphabet_symbol_mapping_items = [
(k, v)
for k, v in self.alphabet._symbol_mapping.items()
if k != anything_else
]
nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64"))
self.__dict__["_fsm_info"] = create_fsm_info(
self.initial,
@@ -110,7 +107,7 @@ def create_fsm_info(

nb_int_list_type = numba.types.ListType(numba.int64)
nb_int_pair_type = numba.types.UniTuple(numba.int64, 2)
nb_unichar_2_type = numba.types.UnicodeCharSeq(2)
nb_unicode_type = numba.types.unicode_type


@numba.njit(cache=True)
@@ -136,7 +133,7 @@ def create_fsm_info(

# use 2-char strings so that we can represent incomplete utf-8 sequences
# as 2-hex-digit pairs
alphabet_symbol_map = numba.typed.Dict.empty(nb_unichar_2_type, numba.int64)
alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64)
for symbol_and_trans_key in alphabet_symbol_mapping_items:
alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1]

@@ -199,7 +196,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}" if byte >= 0x80 else chr(byte)
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
@@ -415,21 +412,19 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the token's symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
@@ -453,7 +448,7 @@

def walk_fsm(
fsm: BetterFSM,
input_string: Sequence[str],
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
@@ -463,13 +458,11 @@
accepted_states: List[int] = []
last_final_idx: int = 0

alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

for i, symbol in enumerate(input_string):
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the token's symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
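The walker itself is unchanged in spirit: it steps through the flat transition map using the precomputed keys instead of resolving each character against the alphabet on every call. A toy, non-numba illustration with a hypothetical two-state transition map (illustrative only; the real function also tracks final states and `full_match`):

```python
# Hypothetical FSM: state 0 --key 7--> state 1, state 1 --key 7--> state 1.
flat_transition_map = {(0, 7): 1, (1, 7): 1}

def walk(trans_key_seq, start_state=0):
    state, accepted = start_state, []
    for trans_key in trans_key_seq:
        new_state = flat_transition_map.get((state, trans_key))
        if new_state is None:   # no transition for this key: dead end
            return []
        state = new_state
        accepted.append(state)
    return accepted

assert walk([7, 7, 7]) == [1, 1, 1]
assert walk([7, 9]) == []   # key 9 has no outgoing transition anywhere
```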
@@ -655,24 +648,25 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_trans_key_seq,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token):
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in token_ids:
@@ -681,9 +675,51 @@
return res


@numba.njit(cache=True, nogil=True)
def get_token_transitions(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
token_str: str,
) -> Sequence[int]:
trans_key_seq = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

trans_key_seq.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
for j in range(len(trans_key_seq)):
trans_key_seq_array[j] = trans_key_seq[j]
return trans_key_seq_array


@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq_array = get_token_transitions(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)

return tokens_trans_keys
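In short, `get_token_transitions` re-splits the flat token string on the "\x00" sentinel: a "\x00" that is not the final character opens a three-character hex symbol, anything else is a single-character symbol, and symbols missing from the alphabet fall back to `alphabet_anything_value`. A plain-Python mirror of that splitting with a made-up alphabet mapping (illustrative only):

```python
def split_symbols(token_str: str):
    # Mirror of the splitting logic above, minus numba and the key lookup.
    i, symbols = 0, []
    while i < len(token_str):
        if token_str[i] == "\x00" and i != len(token_str) - 1:
            symbols.append(token_str[i : i + 3])  # "\x00" + two hex digits
            i += 3
        else:
            symbols.append(token_str[i])          # ordinary single character
            i += 1
    return symbols

assert split_symbols("na\x00C3\x00AFve") == ["n", "a", "\x00C3", "\x00AF", "v", "e"]

# Hypothetical alphabet: symbols not present map to the "anything" key.
alphabet = {"n": 0, "a": 1, "\x00C3": 2, "\x00AF": 3}
anything = 99
keys = [alphabet.get(s, anything) for s in split_symbols("na\x00C3\x00AFve")]
assert keys == [0, 1, 2, 3, anything, anything]
```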


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

@@ -699,6 +735,12 @@
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

@@ -709,6 +751,7 @@
fsm_info.initial,
fsm_info.finals,
vocabulary,
tokens_trans_key_seqs,
start_state,
)

@@ -771,7 +814,7 @@ def gpt2_unicode_to_bytes():
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
@@ -804,7 +847,7 @@ def reduced_vocabulary(
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = tuple(byte_symbol(b) for b in token_bytes)
token_str = "".join(byte_symbol(b) for b in token_bytes)

vocabulary.setdefault(token_str, []).append(token_idx)
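Because tokens are now flat strings, equivalent token ids still group under a single key, and the numba-typed vocabulary can hold `unicode_type` entries directly instead of arrays of 2-char sequences. A small sketch of the grouping step with hypothetical token ids (not the tokenizer-specific byte handling above):

```python
# Hypothetical decoded tokens, already in flat byte-symbol form:
decoded = [(17, "hi"), (42, "hi"), (99, "\x00C3\x00A9")]

vocabulary = {}
for token_idx, token_str in decoded:
    # Token ids that decode to the same string share one vocabulary entry.
    vocabulary.setdefault(token_str, []).append(token_idx)

assert vocabulary == {"hi": [17, 42], "\x00C3\x00A9": [99]}
```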
else:
@@ -813,15 +856,14 @@
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
nb_unichar_2_type[:],
nb_unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
for token_str, token_ids in vocabulary.items():
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token_str, token_ids_np))

return vocabulary_nb, empty_token_ids

14 changes: 9 additions & 5 deletions tests/fsm/test_parsing.py
@@ -9,7 +9,14 @@
from outlines.fsm.parsing import PartialLark, PartialPythonIndenter


def test_partial_parsing():
@pytest.fixture
def cleanup_lark_import():
[Review comment from the PR author] This change isn't directly related to the PR, but it fixes a real problem: if the tests in test_parsing.py fail, the failure cascades and causes a large number of unrelated tests to fail.

yield
# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_partial_parsing(cleanup_lark_import):
lp = PartialLark.open_from_package(
"tests",
"partial_python.lark",
@@ -136,11 +143,8 @@ def test_partial_parsing():
assert len(parser_state.state_stack) == 4
assert parser_state.value_stack[-1].type == "LPAR"

# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_sequential_parse_example():
def test_sequential_parse_example(cleanup_lark_import):
input_tokens = [
"x ",
"= ",