Skip to content

Commit

Permalink
pass token_transition_sequence to walk_fsm in parsing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed May 31, 2024
1 parent 9a8a18e commit da2608d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 24 deletions.
9 changes: 8 additions & 1 deletion outlines/fsm/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from outlines.fsm.regex import (
fsm_union,
get_sub_fsms_from_seq,
get_token_transitions,
make_deterministic_fsm,
walk_fsm,
)
Expand Down Expand Up @@ -569,9 +570,15 @@ def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None)

text_part = text[start_pos:]

text_transitions = get_token_transitions(
self.fsm.fsm_info.alphabet_symbol_mapping,
self.fsm.fsm_info.alphabet_anything_value,
text_part,
)

state_seq = walk_fsm(
self.fsm,
text_part,
text_transitions,
start_state,
full_match=self.match_whole,
)
Expand Down
47 changes: 29 additions & 18 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,32 @@ def state_scan_tokens(
return res


@numba.njit(cache=True, nogil=True)
def get_token_transitions(
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    token_str: str,
) -> Sequence[int]:
    """Map a token string onto its sequence of FSM transition keys.

    Symbols not present in ``alphabet_symbol_mapping`` fall back to
    ``alphabet_anything_value``.  Returns an int64 NumPy array with one
    transition key per symbol.
    """
    keys = []
    pos = 0
    n = len(token_str)
    while pos < n:
        # A null byte that is not the final character introduces a
        # three-character escape sequence; consume it as one symbol.
        if token_str[pos] == "\x00" and pos < n - 1:
            sym = token_str[pos : pos + 3]
            pos += 3
        else:
            sym = token_str[pos]
            pos += 1

        keys.append(alphabet_symbol_mapping.get(sym, alphabet_anything_value))

    # Copy into a fixed-size int64 array (the return type numba expects).
    out = np.empty(len(keys), dtype=np.int64)
    for idx, key in enumerate(keys):
        out[idx] = key
    return out


@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
Expand All @@ -683,24 +709,9 @@ def get_tokens_trans_keys(
) -> List[Sequence[int]]:
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

trans_key_seq.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
for j in range(len(trans_key_seq)):
trans_key_seq_array[j] = trans_key_seq[j]

trans_key_seq_array = get_token_transitions(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)

return tokens_trans_keys
Expand Down
14 changes: 9 additions & 5 deletions tests/fsm/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from outlines.fsm.parsing import PartialLark, PartialPythonIndenter


def test_partial_parsing():
@pytest.fixture
def cleanup_lark_import():
    """Yield to the test, then reload ``lark.lark`` afterwards.

    Reloading restores ``lark.lark.LarkOptions._defaults``, which the
    partial-parsing tests mutate as a side effect.
    """
    yield
    # Clean up lark.lark.LarkOptions._defaults
    importlib.reload(lark.lark)


def test_partial_parsing(cleanup_lark_import):
lp = PartialLark.open_from_package(
"tests",
"partial_python.lark",
Expand Down Expand Up @@ -136,11 +143,8 @@ def test_partial_parsing():
assert len(parser_state.state_stack) == 4
assert parser_state.value_stack[-1].type == "LPAR"

# Clean up lark.lark.LarkOptions._defaults
importlib.reload(lark.lark)


def test_sequential_parse_example():
def test_sequential_parse_example(cleanup_lark_import):
input_tokens = [
"x ",
"= ",
Expand Down

0 comments on commit da2608d

Please sign in to comment.