Skip to content

Commit

Permalink
Use a persistent Tokenizer hash and add tokenizer to cache signature
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 22, 2024
1 parent 7863f8e commit 5d795cf
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 10 deletions.
6 changes: 4 additions & 2 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class RegexGuide(Guide):

def __init__(self, regex_string: str, tokenizer):
@cache()
def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
def create_states_mapping(
regex_string: str, tokenizer: "Tokenizer"
) -> Tuple[dict, set, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
Expand Down Expand Up @@ -142,7 +144,7 @@ def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
self.states_to_token_maps,
self.empty_token_ids,
fsm_finals,
) = create_states_mapping(regex_string)
) = create_states_mapping(regex_string, tokenizer)
self.eos_token_id = tokenizer.eos_token_id
self.final_states = fsm_finals | {-1}

Expand Down
6 changes: 3 additions & 3 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from datasets.fingerprint import Hasher

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,9 +111,7 @@ def __eq__(self, other):
return NotImplemented

def __hash__(self):
from datasets.fingerprint import Hasher

return hash(Hasher.hash(self.tokenizer))
return int(Hasher.hash(self.tokenizer), 16)


class Transformers:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"referencing",
"jsonschema",
"requests",
"datasets",
]
dynamic = ["version"]

Expand All @@ -49,7 +50,6 @@ test = [
"diff-cover",
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
"llama-cpp-python",
"huggingface_hub",
Expand Down
27 changes: 26 additions & 1 deletion tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import outlines.models as models
from outlines.fsm.regex import reduced_vocabulary
from outlines.models.transformers import Transformers, TransformerTokenizer
from outlines.samplers import beam_search, multinomial
from outlines.samplers import beam_search, greedy, multinomial


def test_transformers_integration_text():
Expand Down Expand Up @@ -632,3 +632,28 @@ def test_transformers_use_existing_model_and_tokenizer():
model = Transformers(hf_model, hf_tokenizer)
sequence = generate.text(model)("Write a short sentence ", rng=rng)
assert isinstance(sequence, str)


def test_RegexGuide_caching():
    """The cached states-to-tokens mapping must be keyed on the tokenizer.

    Two models with different tokenizers must not share a cached FSM state
    mapping for the same regex, and each must still generate text matching
    the pattern.
    """
    ip_regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
    prompt = "What is the IP address of the Google DNS servers? "

    first_model = models.transformers(
        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
    )
    first_generator = generate.regex(first_model, ip_regex, sampler=greedy())

    second_model = models.transformers(
        "hf-internal-testing/tiny-random-GPTJForCausalLM"
    )
    second_generator = generate.regex(second_model, ip_regex, sampler=greedy())

    # Different tokenizers must produce different state mappings even though
    # the regex string is identical.
    assert (
        first_generator.fsm.states_to_token_maps
        != second_generator.fsm.states_to_token_maps
    )

    # Sanity check: both generators still emit output matching the pattern.
    first_output = first_generator(prompt, max_tokens=30)
    second_output = second_generator(prompt, max_tokens=30)

    assert re.fullmatch(ip_regex, first_output)
    assert re.fullmatch(ip_regex, second_output)
    assert first_output != second_output
14 changes: 11 additions & 3 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def test_tokenizer_eq_hash():
tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")

tokenizer = TransformerTokenizer(tokenizer_hf)
tokenizer2 = TransformerTokenizer(tokenizer_hf)
assert tokenizer == tokenizer2
assert hash(tokenizer) == hash(tokenizer2)
tokenizer_2 = TransformerTokenizer(tokenizer_hf)

assert tokenizer == tokenizer_2
assert hash(tokenizer) == hash(tokenizer_2)

tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
tokenizer_hf_2.add_tokens(["test_token"])

tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
assert tokenizer != tokenizer_3
assert hash(tokenizer) != hash(tokenizer_3)

0 comments on commit 5d795cf

Please sign in to comment.