Skip to content

Commit

Permalink
index token -> transition key sequence for efficient _walk_fsm
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed May 30, 2024
1 parent 83c4d3a commit 56ea957
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 79 deletions.
105 changes: 56 additions & 49 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,35 +412,19 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
@numba.njit(nogil=True, cache=True)
def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: str,
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

# Iterate over symbols (characters and null-prefixed two-hex-character bytes)
# By default, each symbol is a unicode character
# Except, if the character, input_string[i] == '\x00', then the next two
# in input_string characters are a hex representation of the byte
i = 0
while i < len(input_string):
# if null-byte prefixed its a hex representation
# unless its the last character, then its a trailing null byte symbol
if input_string[i] == "\x00" and i != len(input_string) - 1:
symbol = input_string[i : i + 3]
i += 3
else:
symbol = input_string[i]
i += 1

trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand All @@ -452,19 +436,19 @@ def _walk_fsm(
state = new_state

if state in fsm_finals:
last_final_idx = numba.uint64(i)
last_final_idx = numba.uint64(i + 1)

accepted_states.append(_nonoptional(state))

if full_match and last_final_idx != i:
if full_match and last_final_idx - 1 != i:
return numba.typed.List.empty_list(numba.int64)

return accepted_states


def walk_fsm(
fsm: BetterFSM,
input_string: str,
token_trans_key_seq: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -474,23 +458,11 @@ def walk_fsm(
accepted_states: List[int] = []
last_final_idx: int = 0

alphabet_symbol_mapping = fsm.alphabet._symbol_mapping
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

# See _walk_fsm() explanation of symbol iteration
i = 0
while i < len(input_string):
# if null-byte prefixed its a hex representation
# unless the input string itself is a null byte, then symbol is a lone null-byte
if input_string[i] == "\x00" and input_string != "\x00":
symbol = input_string[i : i + 3]
i += 3
else:
symbol = input_string[i]
i += 1
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand All @@ -502,11 +474,11 @@ def walk_fsm(
state = new_state

if state in fsm_finals:
last_final_idx = i
last_final_idx = i + 1

accepted_states.append(state)

if full_match and last_final_idx != i:
if full_match and last_final_idx - 1 != i:
return []

return accepted_states
Expand Down Expand Up @@ -677,27 +649,24 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, Sequence[int]]],
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

for token, token_ids in vocabulary:
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
fsm_initial,
fsm_finals,
token,
token_trans_key_seq,
start_state,
False,
)

if token == "\x00":
token_length = 1
else:
token_length = len(token) - 2 * token.count("\x00")
if state_seq is not None and len(state_seq) < token_length:
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in token_ids:
Expand All @@ -706,6 +675,37 @@ def state_scan_tokens(
return res


@numba.njit(cache=True, nogil=True)
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
) -> List[Sequence[int]]:
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq = []
i = 0
while i < len(token_str):
if token_str[i] == "\x00" and i != len(token_str) - 1:
symbol = token_str[i : i + 3]
i += 3
else:
symbol = token_str[i]
i += 1

trans_key_seq.append(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

trans_key_seq_array = np.empty(len(trans_key_seq), dtype=np.int64)
for j in range(len(trans_key_seq)):
trans_key_seq_array[j] = trans_key_seq[j]

tokens_trans_keys.append(trans_key_seq_array)

return tokens_trans_keys


def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[str, Sequence[int]]],
Expand All @@ -724,6 +724,12 @@ def create_fsm_index_end_to_end(
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

Expand All @@ -734,6 +740,7 @@ def create_fsm_index_end_to_end(
fsm_info.initial,
fsm_info.finals,
vocabulary,
tokens_trans_key_seqs,
start_state,
)

Expand Down
Loading

0 comments on commit 56ea957

Please sign in to comment.