Skip to content

Commit

Permalink
Use a persistent Tokenizer hash and add tokenizer to cache signature
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 21, 2024
1 parent 7863f8e commit 8136133
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
6 changes: 4 additions & 2 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class RegexGuide(Guide):

def __init__(self, regex_string: str, tokenizer):
@cache()
def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
Expand Down Expand Up @@ -142,7 +144,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string)
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

Expand Down
6 changes: 3 additions & 3 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets.fingerprint import Hasher

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,9 +111,7 @@ def __eq__(self, other):
return NotImplemented

def __hash__(self):
from datasets.fingerprint import Hasher

return hash(Hasher.hash(self.tokenizer))
return int(Hasher.hash(self.tokenizer), 16)


class Transformers:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"referencing",
"jsonschema",
"requests",
"datasets",
]
dynamic = ["version"]

Expand All @@ -49,7 +50,6 @@ test = [
"diff-cover",
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
"llama-cpp-python",
"huggingface_hub",
Expand Down
14 changes: 11 additions & 3 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def test_tokenizer_eq_hash():
tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")

tokenizer = TransformerTokenizer(tokenizer_hf)
tokenizer2 = TransformerTokenizer(tokenizer_hf)
assert tokenizer == tokenizer2
assert hash(tokenizer) == hash(tokenizer2)
tokenizer_2 = TransformerTokenizer(tokenizer_hf)

assert tokenizer == tokenizer_2
assert hash(tokenizer) == hash(tokenizer_2)

tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
tokenizer_hf_2.add_tokens(["test_token"])

tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
assert tokenizer != tokenizer_3
assert hash(tokenizer) != hash(tokenizer_3)

0 comments on commit 8136133

Please sign in to comment.