diff --git a/outlines/caching.py b/outlines/caching.py
index 68207a0e4..52d66af74 100644
--- a/outlines/caching.py
+++ b/outlines/caching.py
@@ -1,15 +1,40 @@
 import asyncio
 import functools
-import hashlib
 import os
 from typing import Callable, Optional
 
 import cloudpickle
-from diskcache import Cache
+from diskcache import Cache, Disk
+from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name
 
 _caching_enabled = True
 
 
+class CloudpickleDisk(Disk):
+    def __init__(self, directory, compress_level=1, **kwargs):
+        self.compress_level = compress_level
+        super().__init__(directory, **kwargs)
+
+    def put(self, key):
+        data = cloudpickle.dumps(key)
+        return super().put(data)
+
+    def get(self, key, raw):
+        data = super().get(key, raw)
+        return cloudpickle.loads(data)
+
+    def store(self, value, read, key=UNKNOWN):
+        if not read:
+            value = cloudpickle.dumps(value)
+        return super().store(value, read, key=key)
+
+    def fetch(self, mode, filename, value, read):
+        data = super().fetch(mode, filename, value, read)
+        if not read:
+            data = cloudpickle.loads(data)
+        return data
+
+
 @functools.lru_cache(1)
 def get_cache():
     """Get the context object that contains previously-computed return values.
@@ -26,7 +51,12 @@ def get_cache():
 
     home_dir = os.path.expanduser("~")
     cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
-    memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
+    memory = Cache(
+        cache_dir,
+        eviction_policy="none",
+        cull_limit=0,
+        disk=CloudpickleDisk,
+    )
 
     # ensure if version upgrade occurs, old cache is pruned
     if outlines_version != memory.get("__version__"):
@@ -36,63 +66,72 @@ def get_cache():
     return memory
 
 
-def hash_arguments(*args, **kwargs) -> str:
-    """Create a hash out of the args and kwargs provided"""
-    result = hashlib.md5()
-    for item in list(args) + sorted(kwargs.items()):
-        result.update(cloudpickle.dumps(item))
-    return result.hexdigest()
-
-
-def cache(key_function: Optional[Callable] = None):
+def cache(expire: Optional[float] = None, typed=False, ignore=()):
     """Caching decorator for memoizing function calls.
+
     The cache key is created based on the values returned by the key_function callable
     if provided or based on the arguments of the decorated function directly otherwise
+
+    This is based on `diskcache`'s `memoize`.
+
     Parameters
     ----------
-    key_function
-        A callable function used to generate a unique key for each function call. It's
-        called with the arguments of the decorated function as arguments
+    expire
+        Seconds until arguments expire.
+    typed
+        Cache different types separately.
+    ignore
+        Positional or keyword arguments to ignore.
+
     Returns
     -------
-    A decorator function that can be applied to other functions.
+    A decorator function that can be applied to other functions.
 
""" def decorator(cached_function: Callable): memory = get_cache() - def wrapper(*args, **kwargs): - if not _caching_enabled: - return cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = cached_function(*args, **kwargs) - memory[cache_key] = result - return result - - async def async_wrapper(*args, **kwargs): - if not _caching_enabled: - return await cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = await cached_function(*args, **kwargs) - memory[cache_key] = result - return result + base = (full_name(cached_function),) if asyncio.iscoroutinefunction(cached_function): - return async_wrapper + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + else: - return wrapper + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper return decorator diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 2833fce1a..d247db62b 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -105,44 +105,44 @@ def copy(self): return self +@cache() +def create_states_mapping( + regex_string: str, tokenizer: "Tokenizer" +) -> Tuple[dict, set, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + regex_pattern = interegular.parse_pattern(regex_string) + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. 
+    if not any(
+        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
+    ):
+        raise ValueError(
+            "The vocabulary does not allow us to build a sequence that matches the input regex"
+        )
+
+    return states_to_token_maps, empty_token_ids, regex_fsm.finals
+
+
 class RegexGuide(Guide):
     """Guide to generate text in the language of a regular expression."""
 
     initial_state = 0
 
     def __init__(self, regex_string: str, tokenizer):
-        @cache()
-        def create_states_mapping(regex_string: str) -> Tuple[dict, set, set]:
-            """Create the variables related to the mapping between states and tokens
-            The parameters of the function are used for caching purpose
-            """
-            regex_pattern = interegular.parse_pattern(regex_string)
-            byte_fsm = make_byte_level_fsm(
-                regex_pattern.to_fsm().reduce(), keep_utf8=True
-            )
-            regex_fsm, _ = make_deterministic_fsm(byte_fsm)
-            states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
-                regex_fsm, tokenizer
-            )
-
-            # We make sure that it is possible to generate strings in the language
-            # of the regular expression with the tokens present in the model's
-            # vocabulary.
-            if not any(
-                regex_fsm.finals.intersection(v.values())
-                for v in states_to_token_maps.values()
-            ):
-                raise ValueError(
-                    "The vocabulary does not allow us to build a sequence that matches the input regex"
-                )
-
-            return states_to_token_maps, empty_token_ids, regex_fsm.finals
-
         (
             self.states_to_token_maps,
             self.empty_token_ids,
             fsm_finals,
-        ) = create_states_mapping(regex_string)
+        ) = create_states_mapping(regex_string, tokenizer)
         self.eos_token_id = tokenizer.eos_token_id
         self.final_states = fsm_finals | {-1}
 
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 3bc59412e..fae9b8e74 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -1,5 +1,7 @@
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from datasets.fingerprint import Hasher
+
 from outlines.models.tokenizer import Tokenizer
 
 if TYPE_CHECKING:
@@ -109,10 +111,15 @@ def __eq__(self, other):
         return NotImplemented
 
     def __hash__(self):
-        from datasets.fingerprint import Hasher
-
         return hash(Hasher.hash(self.tokenizer))
 
+    def __getstate__(self):
+        state = {"tokenizer": self.tokenizer}
+        return state
+
+    def __setstate__(self, state):
+        self.__init__(state["tokenizer"])
+
 
 class Transformers:
     """Represents a `transformers` model."""
diff --git a/pyproject.toml b/pyproject.toml
index b18036ffc..0b310c44b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
     "referencing",
     "jsonschema",
     "requests",
-    "tqdm"
+    "tqdm",
+    "datasets",
 ]
 dynamic = ["version"]
 
@@ -50,7 +51,6 @@ test = [
     "diff-cover",
     "accelerate",
     "beartype<0.16.0",
-    "datasets",
     "responses",
     "llama-cpp-python",
     "huggingface_hub",
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 38525a076..cee3ca312 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -1,6 +1,7 @@
 import datetime
 import re
 from enum import Enum
+from importlib import reload
 from typing import List, Union
 
 import pytest
@@ -11,7 +12,28 @@
 import outlines.models as models
 from outlines.fsm.regex import reduced_vocabulary
 from outlines.models.transformers import Transformers, TransformerTokenizer
-from outlines.samplers import beam_search, multinomial
+from outlines.samplers import beam_search, greedy, multinomial
+
+
+@pytest.fixture
+def temp_cache_dir():
+    import os
+    import tempfile
+
+    import outlines.caching
+    import outlines.fsm.guide
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        os.environ["OUTLINES_CACHE_DIR"] = tempdir
+        outlines.caching.get_cache.cache_clear()
+        reload(outlines)
+        reload(outlines.fsm.guide)
+        cache_status = outlines.caching._caching_enabled
+        try:
+            outlines.caching._caching_enabled = True
+            yield
+        finally:
+            outlines.caching._caching_enabled = cache_status
 
 
 def test_transformers_integration_text():
@@ -632,3 +654,47 @@ def test_transformers_use_existing_model_and_tokenizer():
     model = Transformers(hf_model, hf_tokenizer)
     sequence = generate.text(model)("Write a short sentence ", rng=rng)
     assert isinstance(sequence, str)
+
+
+def test_RegexGuide_caching(temp_cache_dir):
+    import outlines.caching
+    from outlines.fsm.guide import create_states_mapping
+
+    assert outlines.caching._caching_enabled
+
+    regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+    prompt = "What is the IP address of the Google DNS servers? "
+
+    cache = outlines.caching.get_cache()
+
+    # Returns (hits, misses)
+    _ = cache.stats(enable=True)
+    assert cache.statistics
+
+    assert create_states_mapping.__memory__ is cache
+
+    model = models.transformers(
+        "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
+    )
+    generator = generate.regex(model, regex, sampler=greedy())
+    assert cache.stats() == (0, 1)
+
+    model_2 = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM")
+    generator_2 = generate.regex(model_2, regex, sampler=greedy())
+    assert cache.stats() == (0, 2)
+
+    # These two different models and tokenizers should not have the same state
+    # mapping results
+    assert generator.fsm.states_to_token_maps != generator_2.fsm.states_to_token_maps
+
+    generator_3 = generate.regex(model_2, regex, sampler=greedy())
+    assert cache.stats() == (1, 2)
+    assert generator_2.fsm.states_to_token_maps == generator_3.fsm.states_to_token_maps
+
+    # Just for fun...
+    structured = generator(prompt, max_tokens=30)
+    structured_2 = generator_2(prompt, max_tokens=30)
+
+    assert re.fullmatch(regex, structured)
+    assert re.fullmatch(regex, structured_2)
+    assert structured != structured_2
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index b4e410096..f4596a2df 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -107,6 +107,14 @@ def test_tokenizer_eq_hash():
     tokenizer_hf = AutoTokenizer.from_pretrained("gpt2")
 
     tokenizer = TransformerTokenizer(tokenizer_hf)
-    tokenizer2 = TransformerTokenizer(tokenizer_hf)
-    assert tokenizer == tokenizer2
-    assert hash(tokenizer) == hash(tokenizer2)
+    tokenizer_2 = TransformerTokenizer(tokenizer_hf)
+
+    assert tokenizer == tokenizer_2
+    assert hash(tokenizer) == hash(tokenizer_2)
+
+    tokenizer_hf_2 = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer_hf_2.add_tokens(["test_token"])
+
+    tokenizer_3 = TransformerTokenizer(tokenizer_hf_2)
+    assert tokenizer != tokenizer_3
+    assert hash(tokenizer) != hash(tokenizer_3)
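
For reference, a minimal usage sketch of the reworked cache() decorator introduced in this patch. It assumes an environment where outlines with this change is installed; build_index and its toy body are hypothetical placeholders, not part of the patch.

from outlines.caching import cache


@cache(expire=3600)  # cached entries expire after one hour; default (None) never expires
def build_index(regex_string: str) -> str:
    # Hypothetical stand-in for an expensive, deterministic computation.
    return regex_string[::-1]


first = build_index(r"\d{3}")   # first call: computed and written to the on-disk cache
second = build_index(r"\d{3}")  # second call: served from the diskcache Cache
assert first == second

# The wrapper exposes the diskcache Cache object and the key builder, which the
# new test relies on (create_states_mapping.__memory__ is the shared cache).
key = build_index.__cache_key__(r"\d{3}")
assert build_index.__memory__.get(key) == first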