
Commit

ensure byte fsm unicode_type compatibility by prefixing hex-bytes with \x00
lapp0 committed May 30, 2024
1 parent 4a5ef55 commit 83c4d3a
Showing 2 changed files with 79 additions and 28 deletions.
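
This commit makes every FSM symbol and vocabulary token a plain Python string so numba can type it as unicode_type: bytes below 0x80 stay as single characters, while bytes at or above 0x80 are written as a null byte followed by two hex digits. A minimal sketch of that encoding (illustrative only, not part of the diff; the helper name is made up):

def encode_token(token: str) -> str:
    # Mirror byte_symbol below: ASCII bytes stay as characters,
    # all other bytes become "\x00" plus two uppercase hex digits.
    return "".join(
        chr(b) if b < 0x80 else f"\x00{b:02X}" for b in token.encode("utf-8")
    )

assert encode_token("a😂") == "a" + "\x00F0" + "\x009F" + "\x0098" + "\x0082"
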
54 changes: 41 additions & 13 deletions outlines/fsm/regex.py
@@ -196,7 +196,7 @@ def transition_trie_setdefault(


def byte_symbol(byte: int) -> str:
return f"{byte:02X}" if byte >= 0x80 else chr(byte)
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM:
@@ -416,15 +416,29 @@ def _walk_fsm(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
input_string: Sequence[str],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
state = start_state
accepted_states: List[int] = numba.typed.List.empty_list(numba.int64)
last_final_idx: int = numba.uint64(0)

for i, symbol in enumerate(input_string):
# Iterate over symbols (characters and null-prefixed two-hex-character bytes).
# By default each symbol is a single unicode character, except when
# input_string[i] == '\x00': then the next two characters in input_string
# are the hex representation of a byte.
i = 0
while i < len(input_string):
# if null-byte prefixed, it's a hex representation,
# unless it's the last character, in which case it's a trailing null-byte symbol
if input_string[i] == "\x00" and i != len(input_string) - 1:
symbol = input_string[i : i + 3]
i += 3
else:
symbol = input_string[i]
i += 1

trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

new_state = fsm_transitions.get((state, trans_key))
@@ -438,19 +452,19 @@
state = new_state

if state in fsm_finals:
last_final_idx = numba.uint64(i + 1)
last_final_idx = numba.uint64(i)

accepted_states.append(_nonoptional(state))

if full_match and last_final_idx - 1 != i:
if full_match and last_final_idx != i:
return numba.typed.List.empty_list(numba.int64)

return accepted_states


def walk_fsm(
fsm: BetterFSM,
input_string: Sequence[str],
input_string: str,
start_state: int,
full_match: bool = True,
) -> List[int]:
@@ -464,7 +478,17 @@
alphabet_anything_value = fsm.alphabet.anything_value
fsm_transitions = fsm.flat_transition_map

for i, symbol in enumerate(input_string):
# See _walk_fsm() for an explanation of the symbol iteration
i = 0
while i < len(input_string):
# if null-byte prefixed, it's a hex representation,
# unless the input string itself is a null byte, in which case the symbol is a lone null byte
if input_string[i] == "\x00" and input_string != "\x00":
symbol = input_string[i : i + 3]
i += 3
else:
symbol = input_string[i]
i += 1
trans_key = alphabet_symbol_mapping.get(symbol, alphabet_anything_value)

new_state = fsm_transitions.get((state, trans_key))
@@ -478,11 +502,11 @@
state = new_state

if state in fsm_finals:
last_final_idx = i + 1
last_final_idx = i

accepted_states.append(state)

if full_match and last_final_idx - 1 != i:
if full_match and last_final_idx != i:
return []

return accepted_states
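
Both _walk_fsm and walk_fsm recover symbols from such a string with the same rule: a "\x00" starts a three-character hex symbol unless it is the final (or only) character. A standalone sketch of that iteration, under the same assumptions and with an illustrative name:

def iter_symbols(s: str):
    # Split a null-prefixed string back into FSM symbols.
    i = 0
    while i < len(s):
        if s[i] == "\x00" and i != len(s) - 1:
            yield s[i : i + 3]  # "\x00" plus two hex digits is one byte symbol
            i += 3
        else:
            yield s[i]  # ordinary character (or a lone trailing null byte)
            i += 1

assert list(iter_symbols("a\x00F0\x009F\x0098\x0082")) == [
    "a", "\x00F0", "\x009F", "\x0098", "\x0082"
]
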
@@ -652,7 +676,7 @@ def state_scan_tokens(
alphabet_anything_value: int,
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
start_state: int,
) -> Set[Tuple[int, int]]:
res = set()
@@ -669,7 +693,11 @@
False,
)

if state_seq is not None and len(state_seq) < len(token):
if token == "\x00":
token_length = 1
else:
token_length = len(token) - 2 * token.count("\x00")
if state_seq is not None and len(state_seq) < token_length:
continue

for token_id in token_ids:
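
Because each non-ASCII byte now occupies three characters, the number of FSM symbols in a token is its string length minus two per null prefix, which is what token_length computes above (a lone "\x00" token is special-cased to length 1). A worked example, assuming the encoding sketched earlier:

token = " " + "\x00F0" + "\x009F" + "\x0098" + "\x008D"  # ' 😍' as byte symbols
assert len(token) == 13 and token.count("\x00") == 4
assert len(token) - 2 * token.count("\x00") == 5  # five symbols: ' ' plus four bytes
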

def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[Sequence[str], Sequence[int]]],
vocabulary: List[Tuple[str, Sequence[int]]],
) -> Dict[int, Set[Tuple[int, int]]]:
"""Create an FSM state-to-vocabulary map/index through end-to-end token parsing."""

@@ -768,7 +796,7 @@ def gpt2_unicode_to_bytes():
@lru_cache
def reduced_vocabulary(
tokenizer: "Tokenizer",
) -> Tuple[List[Tuple[Sequence[str], Sequence[int]]], Set[int]]:
) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
empty_token_ids = set()
vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {}
53 changes: 38 additions & 15 deletions tests/fsm/test_regex.py
@@ -25,7 +25,7 @@ def identity(s):


def to_bytes(s):
return [chr(b) if b < 0x80 else f"{b:02X}" for b in s.encode("utf-8")]
return [chr(b) if b < 0x80 else f"\x00{b:02X}" for b in s.encode("utf-8")]
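
The updated helper returns a list of per-byte symbols in the new encoding; the tests below join them into a single string before walking the FSM. For example (illustrative):

assert to_bytes("😂") == ["\x00F0", "\x009F", "\x0098", "\x0082"]
assert "".join(to_bytes("😂")) == "\x00F0\x009F\x0098\x0082"  # form passed to walk_fsm
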


def walk_fsm_numba(
@@ -115,19 +115,27 @@ def test_walk_fsm_multi_bytes(function, transform):
str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True)

res = tuple(function(regex_fsm, transform("😂"), regex_fsm.initial, full_match=True))
res = tuple(
function(regex_fsm, "".join(transform("😂")), regex_fsm.initial, full_match=True)
)
assert res[-1:] == (1,)

res = tuple(
function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=False)
function(
regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=False
)
)
assert res[-1:] == (1,)

res = tuple(function(regex_fsm, transform("!"), regex_fsm.initial, full_match=True))
res = tuple(
function(regex_fsm, "".join(transform("!")), regex_fsm.initial, full_match=True)
)
assert res == tuple()

res = tuple(
function(regex_fsm, transform("😂😂"), regex_fsm.initial, full_match=True)
function(
regex_fsm, "".join(transform("😂😂")), regex_fsm.initial, full_match=True
)
)
assert res == tuple()

@@ -304,15 +312,15 @@ def test_create_fsm_index_end_to_end():
vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
numba.types.UnicodeCharSeq(2)[:],
numba.types.unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
token = "".join(token_tuple)
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))
vocabulary_nb.append((token, token_ids_np))

res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb)

@@ -326,28 +334,34 @@ def test_create_fsm_index_end_to_end_multi_byte():
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True)

merge_symbols = lambda byte_hexs: "".join(
["" + b if len(b) == 2 else b for b in byte_hexs]
)

vocabulary = {
"blah": numba.typed.List([0]),
"😈a": numba.typed.List([1]),
"😇": numba.typed.List([2]),
"😍": numba.typed.List([3]),
("F0", "9F", "98", "8D"): numba.typed.List([4]), # '😍'
merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍'
" 😍": numba.typed.List([5]),
(" ", "F0", "9F", "98", "8D"): numba.typed.List([6]), # ' 😍'
(" ", "F0", "9F", "98"): numba.typed.List([7]), # ' 😍' incomplete
merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍'
merge_symbols((" ", "F0", "9F", "98")): numba.typed.List(
[7]
), # ' 😍' incomplete
"<EOS>": numba.typed.List([8]),
}

vocabulary_nb = numba.typed.List.empty_list(
numba.types.Tuple(
(
numba.types.UnicodeCharSeq(2)[:],
numba.types.unicode_type,
numba.int64[:],
)
)
)
for token_tuple, token_ids in vocabulary.items():
token_tuple_np = np.fromiter(token_tuple, dtype=np.dtype("U2"))
token_tuple_np = merge_symbols(token_tuple)
token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64"))
vocabulary_nb.append((token_tuple_np, token_ids_np))

@@ -356,7 +370,16 @@ def test_create_fsm_index_end_to_end_multi_byte():
assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}}
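
merge_symbols above flattens a tuple of byte hexes into the same null-prefixed string form used elsewhere; for instance (illustrative, assuming the lambda as defined in this test):

assert merge_symbols(("F0", "9F", "98", "8D")) == "\x00F0\x009F\x0098\x008D"  # '😍'
assert merge_symbols((" ", "F0", "9F", "98", "8D")) == " " + "\x00F0\x009F\x0098\x008D"  # ' 😍'
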


def test_create_fsm_index_tokenizer():
@pytest.mark.parametrize(
"hf_tokenizer_uri",
[
"gpt2",
"microsoft/phi-2",
"Qwen/Qwen1.5-0.5B-Chat",
"NousResearch/Hermes-2-Pro-Llama-3-8B",
],
)
def test_create_fsm_index_tokenizer(hf_tokenizer_uri):
# The combined regular expressions of a lexer state in a Python grammar
regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~"

@@ -371,7 +394,7 @@ def test_create_fsm_index_tokenizer():
num_bytes_fsm_states = len(bytes_fsm.states)
assert num_bytes_fsm_states == 235

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_uri)
tokenizer = TransformerTokenizer(tokenizer)

states_to_token_subsets, empty_token_ids = create_fsm_index_tokenizer(
