Skip to content

Commit

Permalink
Revert "Incorporate Trie into fsm index calculation"
Browse files (browse the repository at this point in the history)
This reverts commit 591ad2a.
  • Loading branch information
lapp0 committed Jun 4, 2024
1 parent c314cb8 commit 0ad30e0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 276 deletions.
43 changes: 15 additions & 28 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
from numba.typed.typedobjectutils import _nonoptional
from tqdm import tqdm

from outlines.fsm.vocab_trie import VocabTrie

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer

Expand Down Expand Up @@ -651,38 +649,29 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, Sequence[int]]],
vocab_trie: VocabTrie,
token_trans_key_seqs: List[Sequence[int]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()

# Initialize the stack with tokens having no prefixes
stack = numba.typed.List()
for token_transitions_seq in vocab_trie.get_children():
stack.append(token_transitions_seq)

# Process the tokens using the stack
while len(stack) > 0:
token_transition_seq = stack.pop()
for (token, token_ids), token_trans_key_seq in zip(
vocabulary, token_trans_key_seqs
):
state_seq = _walk_fsm(
fsm_transitions,
fsm_initial,
fsm_finals,
token_transition_seq,
token_trans_key_seq,
start_state,
False,
)

if state_seq is not None and len(state_seq) < len(token_transition_seq):
if state_seq is not None and len(state_seq) < len(token_trans_key_seq):
continue

for token_id in vocab_trie.get_token_ids(token_transition_seq):
for token_id in token_ids:
res.add((token_id, state_seq[-1]))

# Add successors to the stack
for new_token in vocab_trie.get_children(token_transition_seq):
stack.append(new_token)

return res


Expand Down Expand Up @@ -713,7 +702,7 @@ def get_token_transitions(


@numba.njit(cache=True, nogil=True)
def get_all_token_transitions(
def get_tokens_trans_keys(
alphabet_symbol_mapping: Dict[str, int],
alphabet_anything_value: int,
vocabulary: List[Tuple[str, Sequence[int]]],
Expand All @@ -740,20 +729,18 @@ def create_fsm_index_end_to_end(
seen: Set[int] = set()
next_states = {fsm_info.initial}

all_token_transitions = get_all_token_transitions(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

vocab_trie = VocabTrie(all_token_transitions, vocabulary)

pbar = tqdm(
total=len(set(fsm_info.transitions.values()))
+ 1, # all transitions plus initial
desc="Compiling FSM index for all state transitions",
)

tokens_trans_key_seqs = get_tokens_trans_keys(
fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
)

while next_states:
start_state = next_states.pop()

Expand All @@ -764,7 +751,7 @@ def create_fsm_index_end_to_end(
fsm_info.initial,
fsm_info.finals,
vocabulary,
vocab_trie,
tokens_trans_key_seqs,
start_state,
)

Expand Down
241 changes: 0 additions & 241 deletions outlines/fsm/vocab_trie.py

This file was deleted.

17 changes: 10 additions & 7 deletions tests/fsm/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
create_fsm_index_end_to_end,
create_fsm_index_tokenizer,
fsm_union,
get_all_token_transitions,
get_sub_fsms_from_seq,
get_token_transitions,
get_tokens_trans_keys,
make_byte_level_better_fsm,
make_byte_level_fsm,
make_deterministic_fsm,
Expand All @@ -34,11 +33,15 @@ def merge_symbols(byte_hexs):


def token_str_to_trans_key(fsm, input_string):
return get_token_transitions(
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple((numba.types.unicode_type, numba.int64[:]))
)
vocabulary_nb.append((input_string, np.fromiter([1], dtype=np.dtype("int64"))))
return get_tokens_trans_keys(
fsm.fsm_info.alphabet_symbol_mapping,
fsm.fsm_info.alphabet_anything_value,
input_string,
)
vocabulary_nb,
)[0]


def walk_fsm_from_token_str(
Expand Down Expand Up @@ -595,7 +598,7 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_all_token_transitions(
token_trans_keys = get_tokens_trans_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
vocabulary,
Expand Down Expand Up @@ -630,7 +633,7 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_all_token_transitions(
token_trans_keys = get_tokens_trans_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
vocabulary,
Expand Down

0 comments on commit 0ad30e0

Please sign in to comment.