From c247c319e179164faed9ba77a3ea695d4181407f Mon Sep 17 00:00:00 2001 From: pgrundmann Date: Tue, 11 Jun 2024 14:32:20 +0200 Subject: [PATCH] WIP --- build/lib/outlines/__init__.py | 19 + build/lib/outlines/base.py | 288 ++++++ build/lib/outlines/caching.py | 179 ++++ build/lib/outlines/fsm/__init__.py | 0 build/lib/outlines/fsm/fsm.py | 69 ++ build/lib/outlines/fsm/guide.py | 426 ++++++++ build/lib/outlines/fsm/json_schema.py | 436 +++++++++ build/lib/outlines/fsm/parsing.py | 870 +++++++++++++++++ build/lib/outlines/fsm/regex.py | 922 ++++++++++++++++++ build/lib/outlines/fsm/types.py | 81 ++ build/lib/outlines/fsm/vocab_trie.py | 241 +++++ build/lib/outlines/function.py | 117 +++ build/lib/outlines/generate/__init__.py | 8 + build/lib/outlines/generate/api.py | 531 ++++++++++ build/lib/outlines/generate/cfg.py | 64 ++ build/lib/outlines/generate/choice.py | 36 + build/lib/outlines/generate/format.py | 45 + build/lib/outlines/generate/fsm.py | 14 + build/lib/outlines/generate/generator.py | 312 ++++++ build/lib/outlines/generate/json.py | 78 ++ build/lib/outlines/generate/regex.py | 73 ++ build/lib/outlines/generate/text.py | 57 ++ build/lib/outlines/grammars.py | 14 + build/lib/outlines/grammars/arithmetic.lark | 18 + build/lib/outlines/grammars/common.lark | 80 ++ build/lib/outlines/grammars/json.lark | 19 + build/lib/outlines/integrations/__init__.py | 1 + build/lib/outlines/integrations/llamacpp.py | 191 ++++ .../lib/outlines/integrations/transformers.py | 159 +++ build/lib/outlines/integrations/utils.py | 103 ++ build/lib/outlines/integrations/vllm.py | 177 ++++ build/lib/outlines/models/__init__.py | 17 + build/lib/outlines/models/exllamav2.py | 232 +++++ build/lib/outlines/models/llamacpp.py | 391 ++++++++ build/lib/outlines/models/mamba.py | 61 ++ build/lib/outlines/models/openai.py | 467 +++++++++ build/lib/outlines/models/tokenizer.py | 31 + build/lib/outlines/models/transformers.py | 236 +++++ build/lib/outlines/models/vllm.py | 159 +++ build/lib/outlines/prompts.py | 338 +++++++ build/lib/outlines/py.typed | 0 build/lib/outlines/samplers.py | 324 ++++++ build/lib/outlines/serve/__init__.py | 0 build/lib/outlines/serve/serve.py | 136 +++ build/lib/outlines/serve/vllm.py | 4 + build/lib/outlines/types/__init__.py | 4 + build/lib/outlines/types/airports.py | 16 + build/lib/outlines/types/countries.py | 24 + build/lib/outlines/types/email.py | 11 + build/lib/outlines/types/isbn.py | 12 + build/lib/outlines/types/locales.py | 21 + build/lib/outlines/types/phone_numbers.py | 16 + build/lib/outlines/types/zip_codes.py | 13 + outlines/fsm/json_schema.py | 2 +- outlines/generate/json.py | 2 +- outlines/integrations/vllm.py | 33 +- 56 files changed, 8167 insertions(+), 11 deletions(-) create mode 100644 build/lib/outlines/__init__.py create mode 100644 build/lib/outlines/base.py create mode 100644 build/lib/outlines/caching.py create mode 100644 build/lib/outlines/fsm/__init__.py create mode 100644 build/lib/outlines/fsm/fsm.py create mode 100644 build/lib/outlines/fsm/guide.py create mode 100644 build/lib/outlines/fsm/json_schema.py create mode 100644 build/lib/outlines/fsm/parsing.py create mode 100644 build/lib/outlines/fsm/regex.py create mode 100644 build/lib/outlines/fsm/types.py create mode 100644 build/lib/outlines/fsm/vocab_trie.py create mode 100644 build/lib/outlines/function.py create mode 100644 build/lib/outlines/generate/__init__.py create mode 100644 build/lib/outlines/generate/api.py create mode 100644 build/lib/outlines/generate/cfg.py create mode 100644 
build/lib/outlines/generate/choice.py create mode 100644 build/lib/outlines/generate/format.py create mode 100644 build/lib/outlines/generate/fsm.py create mode 100644 build/lib/outlines/generate/generator.py create mode 100644 build/lib/outlines/generate/json.py create mode 100644 build/lib/outlines/generate/regex.py create mode 100644 build/lib/outlines/generate/text.py create mode 100644 build/lib/outlines/grammars.py create mode 100644 build/lib/outlines/grammars/arithmetic.lark create mode 100644 build/lib/outlines/grammars/common.lark create mode 100644 build/lib/outlines/grammars/json.lark create mode 100644 build/lib/outlines/integrations/__init__.py create mode 100644 build/lib/outlines/integrations/llamacpp.py create mode 100644 build/lib/outlines/integrations/transformers.py create mode 100644 build/lib/outlines/integrations/utils.py create mode 100644 build/lib/outlines/integrations/vllm.py create mode 100644 build/lib/outlines/models/__init__.py create mode 100644 build/lib/outlines/models/exllamav2.py create mode 100644 build/lib/outlines/models/llamacpp.py create mode 100644 build/lib/outlines/models/mamba.py create mode 100644 build/lib/outlines/models/openai.py create mode 100644 build/lib/outlines/models/tokenizer.py create mode 100644 build/lib/outlines/models/transformers.py create mode 100644 build/lib/outlines/models/vllm.py create mode 100644 build/lib/outlines/prompts.py create mode 100644 build/lib/outlines/py.typed create mode 100644 build/lib/outlines/samplers.py create mode 100644 build/lib/outlines/serve/__init__.py create mode 100644 build/lib/outlines/serve/serve.py create mode 100644 build/lib/outlines/serve/vllm.py create mode 100644 build/lib/outlines/types/__init__.py create mode 100644 build/lib/outlines/types/airports.py create mode 100644 build/lib/outlines/types/countries.py create mode 100644 build/lib/outlines/types/email.py create mode 100644 build/lib/outlines/types/isbn.py create mode 100644 build/lib/outlines/types/locales.py create mode 100644 build/lib/outlines/types/phone_numbers.py create mode 100644 build/lib/outlines/types/zip_codes.py diff --git a/build/lib/outlines/__init__.py b/build/lib/outlines/__init__.py new file mode 100644 index 000000000..3eb6a2f94 --- /dev/null +++ b/build/lib/outlines/__init__.py @@ -0,0 +1,19 @@ +"""Outlines is a Generative Model Programming Framework.""" +import outlines.generate +import outlines.grammars +import outlines.models +import outlines.types +from outlines.base import vectorize +from outlines.caching import clear_cache, disable_cache, get_cache +from outlines.function import Function +from outlines.prompts import prompt + +__all__ = [ + "clear_cache", + "disable_cache", + "get_cache", + "Function", + "prompt", + "vectorize", + "grammars", +] diff --git a/build/lib/outlines/base.py b/build/lib/outlines/base.py new file mode 100644 index 000000000..4de8ccf5a --- /dev/null +++ b/build/lib/outlines/base.py @@ -0,0 +1,288 @@ +import asyncio +import builtins +import functools +import inspect +from typing import Callable, Optional + +import numpy as np +from numpy.lib.function_base import ( + _calculate_shapes, + _parse_gufunc_signature, + _parse_input_dimensions, + _update_dim_sizes, +) + +# Allow nested loops for running in notebook. We don't enable it globally as it +# may interfere with other libraries that use asyncio. +if hasattr(builtins, "__IPYTHON__"): + try: + import nest_asyncio + + nest_asyncio.apply() + except ImportError: + print( + "Couldn't patch nest_asyncio because it's not installed. 
Running in the notebook might have issues" + ) + + +class vectorize: + """Returns an object that acts like a function but takes arrays as an input. + + The vectorized function evaluates `func` over successive tuples of the input + chararrays and returns a single NumPy chararray or a tuple of NumPy chararrays. + + Its behavior is similar to NumPy's `vectorize` for Python functions: the function + being vectorized is executed in a `for` loop. Coroutines, however, are executed + concurrently. + + Part of the code was adapted from `numpy.lib.function_base`. + + """ + + def __init__(self, func: Callable, signature: Optional[str] = None): + self.func = func + self.signature = signature + self.is_coroutine_fn = inspect.iscoroutinefunction(func) + + functools.update_wrapper(self, func) + + if signature is not None: + # Parse the signature string into a Python data structure. + # For instance "(m),(s)->(s,m)" becomes `([(m,),(s,)],[(s,m)])`. + self._in_and_out_core_dimensions = _parse_gufunc_signature(signature) + else: + self._in_and_out_core_dimensions = None + + def __call__(self, *args, **kwargs): + """Call the vectorized function.""" + if not args and not kwargs: + return self.call_thunk() + elif self.signature is not None: + return self.call_with_signature(*args, **kwargs) + else: + return self.call_no_signature(*args, **kwargs) + + def call_thunk(self): + """Call a vectorized thunk. + + Thunks have no arguments and can thus be called directly. + + """ + if self.is_coroutine_fn: + loop = asyncio.new_event_loop() + try: + outputs = loop.run_until_complete(self.func()) + finally: + loop.close() + else: + outputs = self.func() + + return outputs + + def call_no_signature(self, *args, **kwargs): + """Call functions and coroutines when no signature is specified. + + When no signature is specified we assume that all of the function's + inputs and outputs are scalars (core dimension of zero). We first + broadcast the input arrays, then iteratively apply the function over the + elements of the broadcasted arrays and finally reshape the results to + match the input shape. + + Functions are executed in a for loop, coroutines are executed + concurrently. + + """ + # Convert args and kwargs to arrays + args = [np.array(arg) for arg in args] + kwargs = {key: np.array(value) for key, value in kwargs.items()} + + # Broadcast args and kwargs + broadcast_shape = np.broadcast(*args, *list(kwargs.values())).shape + args = [np.broadcast_to(arg, broadcast_shape) for arg in args] + kwargs = { + key: np.broadcast_to(value, broadcast_shape) + for key, value in kwargs.items() + } + + # Execute functions in a loop, and coroutines concurrently + if self.is_coroutine_fn: + outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs) + else: + outputs = self.vectorize_call(broadcast_shape, args, kwargs) + + # `outputs` is a flat array or a tuple of flat arrays. We reshape the arrays + # to match the input shape.
+ outputs = [ + results if isinstance(results, tuple) else (results,) for results in outputs + ] + outputs = tuple( + [np.asarray(x).reshape(broadcast_shape).squeeze() for x in zip(*outputs)] + ) + outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs]) + + n_results = len(list(outputs)) + + return outputs[0] if n_results == 1 else outputs + + def call_with_signature(self, *args, **kwargs): + """Call functions and coroutines when a signature is specified.""" + input_core_dims, output_core_dims = self._in_and_out_core_dimensions + + # Make sure that the numbers of arguments passed is compatible with + # the signature. + num_args = len(args) + len(kwargs) + if num_args != len(input_core_dims): + raise TypeError( + "wrong number of positional arguments: " + "expected %r, got %r" % (len(input_core_dims), len(args)) + ) + + # Convert args and kwargs to arrays + args = [np.asarray(arg) for arg in args] + kwargs = {key: np.array(value) for key, value in kwargs.items()} + + # Find the arguments' broadcast shape, and map placeholder + # variables in the signature to the number of dimensions + # they correspond to given the arguments. + broadcast_shape, dim_sizes = _parse_input_dimensions( + args + list(kwargs.values()), input_core_dims + ) + + # Calculate the shape to which each of the arguments should be broadcasted + # and reshape them accordingly. + input_shapes = _calculate_shapes(broadcast_shape, dim_sizes, input_core_dims) + args = [ + np.broadcast_to(arg, shape, subok=True) + for arg, shape in zip(args, input_shapes) + ] + kwargs = { + key: np.broadcast_to(value, broadcast_shape) + for key, value in kwargs.items() + } + + n_out = len(output_core_dims) + + if self.is_coroutine_fn: + outputs = self.vectorize_call_coroutine(broadcast_shape, args, kwargs) + else: + outputs = self.vectorize_call(broadcast_shape, args, kwargs) + + outputs = [ + results if isinstance(results, tuple) else (results,) for results in outputs + ] + + flat_outputs = list(zip(*outputs)) + n_results = len(flat_outputs) + + if n_out != n_results: + raise ValueError( + f"wrong number of outputs from the function, expected {n_out}, got {n_results}" + ) + + # The number of dimensions of the outputs are not necessarily known in + # advance. The following iterates over the results and updates the + # number of dimensions of the outputs accordingly. + for results, core_dims in zip(flat_outputs, output_core_dims): + for result in results: + _update_dim_sizes(dim_sizes, result, core_dims) + + # Calculate the shape to which each of the outputs should be broadcasted + # and reshape them. + shapes = _calculate_shapes(broadcast_shape, dim_sizes, output_core_dims) + outputs = tuple( + [ + np.hstack(results).reshape(shape).squeeze() + for shape, results in zip(shapes, zip(*outputs)) + ] + ) + outputs = tuple([x.item() if np.ndim(x) == 0 else x for x in outputs]) + + return outputs[0] if n_results == 1 else outputs + + def vectorize_call(self, broadcast_shape, args, kwargs): + """Run the function in a for loop. + + A possible extension would be to parallelize the calls. + + Parameters + ---------- + broadcast_shape + The brodcast shape of the input arrays. + args + The function's broadcasted arguments. + kwargs + The function's broadcasted keyword arguments. 
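As an aside, a minimal usage sketch of the `vectorize` wrapper documented above, assuming an async workload; the `shout` coroutine and its inputs are invented for illustration and are not part of this patch.

import asyncio

from outlines.base import vectorize


async def shout(word: str) -> str:
    # Stand-in for an asynchronous model or API call (assumption).
    await asyncio.sleep(0)
    return word.upper()


# With no signature, inputs are treated as scalars, broadcast together, and the
# coroutine instances are gathered concurrently in a fresh event loop.
batched_shout = vectorize(shout)
print(batched_shout(["hello", "world"]))  # e.g. ['HELLO' 'WORLD'] as a NumPy array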
+ + """ + outputs = [] + for index in np.ndindex(*broadcast_shape): + current_args = tuple(arg[index] for arg in args) + current_kwargs = {key: value[index] for key, value in kwargs.items()} + outputs.append(self.func(*current_args, **current_kwargs)) + + return outputs + + def vectorize_call_coroutine(self, broadcast_shape, args, kwargs): + """Run coroutines concurrently. + + Creates as many tasks as needed and executes them in a new event + loop. + + Parameters + ---------- + broadcast_shape + The brodcast shape of the input arrays. + args + The function's broadcasted arguments. + kwargs + The function's broadcasted keyword arguments. + + """ + + async def create_and_gather_tasks(): + tasks = [] + for index in np.ndindex(*broadcast_shape): + current_args = tuple(arg[index] for arg in args) + current_kwargs = {key: value[index] for key, value in kwargs.items()} + tasks.append(self.func(*current_args, **current_kwargs)) + + outputs = await asyncio.gather(*tasks) + + return outputs + + loop = asyncio.new_event_loop() + try: + outputs = loop.run_until_complete(create_and_gather_tasks()) + finally: + loop.close() + + return outputs + + +def _update_arrays_type(arrays, results): + """Update the dtype of arrays. + + String arrays contain strings of fixed length. Here they are initialized with + the type of the first results, so that if the next results contain longer + strings they will be truncated when added to the output arrays. Here we + update the type if the current results contain longer strings than in the + current output array. + + Parameters + ---------- + arrays + Arrays that contain the vectorized function's results. + results + The current output of the function being vectorized. + + """ + + updated_arrays = [] + for array, result in zip(arrays, results): + if array.dtype.type == np.str_: + if array.dtype < np.array(result).dtype: + array = array.astype(np.array(result).dtype) + + updated_arrays.append(array) + + return tuple(updated_arrays) diff --git a/build/lib/outlines/caching.py b/build/lib/outlines/caching.py new file mode 100644 index 000000000..95392c7e8 --- /dev/null +++ b/build/lib/outlines/caching.py @@ -0,0 +1,179 @@ +import asyncio +import contextlib +import functools +import os +from typing import Callable, Optional + +import cloudpickle +from diskcache import Cache, Disk +from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name + +_caching_enabled = True + + +class CloudpickleDisk(Disk): + def __init__(self, directory, compress_level=1, **kwargs): + self.compress_level = compress_level + super().__init__(directory, **kwargs) + + def put(self, key): + data = cloudpickle.dumps(key) + return super().put(data) + + def get(self, key, raw): + data = super().get(key, raw) + return cloudpickle.loads(data) + + def store(self, value, read, key=UNKNOWN): + if not read: + value = cloudpickle.dumps(value) + return super().store(value, read, key=key) + + def fetch(self, mode, filename, value, read): + data = super().fetch(mode, filename, value, read) + if not read: + data = cloudpickle.loads(data) + return data + + +@functools.lru_cache(1) +def get_cache(): + """Get the context object that contains previously-computed return values. + + The cache is used to avoid unnecessary computations and API calls, which can + be long and expensive for large models. + + The cache directory defaults to `HOMEDIR/.cache/outlines`, but this choice + can be overridden by the user by setting the value of the `OUTLINES_CACHE_DIR` + environment variable. 
+ + """ + from outlines._version import __version__ as outlines_version # type: ignore + + home_dir = os.path.expanduser("~") + cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) + + # ensure if version upgrade occurs, old cache is pruned + if outlines_version != memory.get("__version__"): + memory.clear() + memory["__version__"] = outlines_version + + return memory + + +def cache(expire: Optional[float] = None, typed=False, ignore=()): + """Caching decorator for memoizing function calls. + + The cache key is created based on the values returned by the key_function callable + if provided or based on the arguments of the decorated function directly otherwise + + This is based on `diskcache`'s `memoize`. + + Parameters + ---------- + expire + Seconds until arguments expire. + typed + Cache different types separately. + ignore + Positional or keyword arguments to ignore. + + Returns + ------- + A decorator function that can be applied to other functions. + """ + + def decorator(cached_function: Callable): + memory = get_cache() + + base = (full_name(cached_function),) + + if asyncio.iscoroutinefunction(cached_function): + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + else: + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper + + return decorator + + +def disable_cache(): + """Disable the cache for this session. + + Generative models output different results each time they are called when + sampling. This can be a desirable property for some workflows, in which case + one can call `outlines.call.disable` to disable the cache for the session. + + This function does not delete the cache, call `outlines.cache.clear` + instead. It also does not overwrite the cache with the values returned + during the session. 
+ + Example + ------- + + `outlines.cache.disable` should be called right after importing outlines: + + >>> import outlines.cache as cache + >>> cache.disable() + + """ + global _caching_enabled + _caching_enabled = False + + +def clear_cache(): + """Erase the cache completely.""" + memory = get_cache() + memory.clear() + + +@contextlib.contextmanager +def cache_disabled(): + # outlines.caching._caching_enabled + global _caching_enabled + original_state = _caching_enabled + _caching_enabled = False + try: + yield + finally: + _caching_enabled = original_state diff --git a/build/lib/outlines/fsm/__init__.py b/build/lib/outlines/fsm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/build/lib/outlines/fsm/fsm.py b/build/lib/outlines/fsm/fsm.py new file mode 100644 index 000000000..4a7fce8c9 --- /dev/null +++ b/build/lib/outlines/fsm/fsm.py @@ -0,0 +1,69 @@ +import warnings +from typing import TYPE_CHECKING, Iterable, NewType, Optional + +from outlines.fsm.guide import CFGGuide, RegexGuide, StopAtEOSGuide + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + +FSMState = NewType("FSMState", int) + + +class StopAtEosFSM(StopAtEOSGuide): + """FSM to generate text until EOS has been generated.""" + + def __init__(self, tokenizer: "Tokenizer"): + warnings.warn( + UserWarning( + "The `StopAtTokenFSM` interface is deprecated and will be removed on 2024-06-01. Please use `StopAtEOSGuide` instead." + ) + ) + super().__init__(tokenizer) + + def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: + next_instruction = self.get_next_instruction(state) + return next_instruction.tokens + + def next_state(self, state: FSMState, token_id: int) -> FSMState: + return FSMState(self.get_next_state(state, token_id)) + + +class RegexFSM(RegexGuide): + """FSM to generate text that is in the language of a regular expression.""" + + def __init__(self, regex_string: str, tokenizer): + warnings.warn( + UserWarning( + "The `RegexFSM` interface is deprecated and will be removed on 2024-06-01. Please use `RegexGuide` instead." + ) + ) + super().__init__(regex_string, tokenizer) + + def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: + next_instruction = self.get_next_instruction(state) + return next_instruction.tokens + + def next_state(self, state: FSMState, token_id: int) -> FSMState: + return FSMState(self.get_next_state(state, token_id)) + + +class CFGFSM(CFGGuide): + """FSM to generate text that is in the language of a context-free grammar.""" + + def __init__(self, cfg_string: str, tokenizer): + warnings.warn( + UserWarning( + "The `CFGFSM` interface is deprecated and will be removed on 2024-06-01. Please use `CFGGuide` instead." 
+ ) + ) + super().__init__(cfg_string, tokenizer) + + def allowed_token_ids(self, state: FSMState) -> Optional[Iterable[int]]: + return self.get_next_instruction(state).tokens + + def next_state(self, state: FSMState, token_id: int) -> FSMState: + return FSMState(self.get_next_state(state, token_id)) + + def copy(self) -> "CFGFSM": + """Create a copy of the FSM.""" + return CFGFSM(self.cfg_string, self.tokenizer) diff --git a/build/lib/outlines/fsm/guide.py b/build/lib/outlines/fsm/guide.py new file mode 100644 index 000000000..d247db62b --- /dev/null +++ b/build/lib/outlines/fsm/guide.py @@ -0,0 +1,426 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Protocol, Tuple, Union + +import interegular +from lark import Lark + +from outlines import grammars +from outlines.caching import cache +from outlines.fsm.regex import ( + create_fsm_index_tokenizer, + make_byte_level_fsm, + make_deterministic_fsm, +) + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +@dataclass(frozen=True) +class Write: + """Write instruction. + + Attributes + ---------- + tokens + The sequence of tokens to be added to the current sequence by the + generation process. + + """ + + tokens: List[int] + + +@dataclass(frozen=True) +class Generate: + """Generate instruction + + Attributes + ---------- + tokens + The tokens that lead to a valid completion if generated. A value + of ``None`` indicates that all tokens are allowed. + """ + + tokens: Optional[List[int]] + + +Instruction = Union[Write, Generate] + + +class Guide(Protocol): + """Base definition of a generation guide. + + A generation guide defines the behavior of a finite-state machine that guides + a text generation procedure. Unlike the DFAs built from regular expressions + guides can also emit a `Write` instructions which tells the model that it can + append a sequence of tokens (or token word) instead of generating it. + + """ + + def get_next_instruction(self, state: int) -> Instruction: + ... + + def get_next_state(self, state: int, token_id: int) -> int: + ... + + def is_final_state(self, state: int) -> bool: + ... + + def copy(self) -> "Guide": + ... + + +class StopAtEOSGuide(Guide): + """Guide to generate tokens until the EOS token has been generated.""" + + final_state = 1 + start_state = 0 + + def __init__(self, tokenizer: "Tokenizer"): + """Initialize the generation guide. + + model + The logit generator used to generate the next token. 
+ + """ + self.eos_token_id = tokenizer.eos_token_id + self.vocabulary = tokenizer.vocabulary.values() + + def get_next_instruction(self, state: int) -> Instruction: + if self.is_final_state(state): + return Write([self.eos_token_id]) + return Generate(None) + + def get_next_state(self, state: int, token_id: int) -> int: + if token_id == self.eos_token_id or state == self.final_state: + return self.final_state + + return self.start_state + + def is_final_state(self, state: int): + return state == self.final_state + + def copy(self): + return self + + +@cache() +def create_states_mapping( + regex_string: str, tokenizer: "Tokenizer" +) -> Tuple[dict, set, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + regex_pattern = interegular.parse_pattern(regex_string) + byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + return states_to_token_maps, empty_token_ids, regex_fsm.finals + + +class RegexGuide(Guide): + """Guide to generate text in the language of a regular expression.""" + + initial_state = 0 + + def __init__(self, regex_string: str, tokenizer): + ( + self.states_to_token_maps, + self.empty_token_ids, + fsm_finals, + ) = create_states_mapping(regex_string, tokenizer) + self.eos_token_id = tokenizer.eos_token_id + self.final_states = fsm_finals | {-1} + + def get_next_instruction(self, state: int) -> Instruction: + """Return the next instruction for guided generation. + + The initialization of the guide builds an index which maps FSM states to a + map from authorized tokens to the state in which the guide needs to move + if said token is generated. Therefore the authorized tokens at the + current state are the keys of the map returned by the value of the index + for current state. + + If the current state is not contained in the end this means that we are + in a final state of the guide. We only authorize EOS tokens in the final + state. + + Parameters + ---------- + state + The current state of the guide. + + Returns + ------- + A `Generate` instance that contains the model and the allowed token ids. + + """ + next_tokens_to_end_states = self.states_to_token_maps.get(state) + if next_tokens_to_end_states is None: + return Write([self.eos_token_id]) + + return Generate(list(next_tokens_to_end_states.keys())) + + def get_next_state(self, state: int, token_id: int) -> int: + """Update the state of the guide. + + We use the index to determine to which state the guide should transition + given the token that was just generated. + + Parameters + ---------- + state + The current state of the guide. + token_id + The id of the token that was just generated. + + Returns + ------- + The new state of the guide. 
+ + """ + if token_id == self.eos_token_id or state not in self.states_to_token_maps: + return -1 + + last_token_to_end_state = self.states_to_token_maps[state] + next_state = last_token_to_end_state.get(token_id) + if next_state is None: + next_state = -1 + + return next_state + + @classmethod + def from_interegular_fsm( + cls, interegular_fsm: interegular.fsm.FSM, tokenizer: "Tokenizer" + ): + from_interegular_instance = cls.__new__(cls) + + def create_states_mapping_from_interegular_fsm( + fsm: interegular.fsm.FSM, + ) -> Tuple[dict, set]: + """Create the variables related to the mapping between states and tokens + The parameters of the function are used for caching purpose + """ + byte_fsm = make_byte_level_fsm(fsm.reduce(), keep_utf8=True) + regex_fsm, _ = make_deterministic_fsm(byte_fsm) + states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer( + regex_fsm, tokenizer + ) + + # We make sure that it is possible to generate strings in the language + # of the regular expression with the tokens present in the model's + # vocabulary. + if not any( + regex_fsm.finals.intersection(v.values()) + for v in states_to_token_maps.values() + ): + raise ValueError( + "The vocabulary does not allow us to build a sequence that matches the input regex" + ) + + return states_to_token_maps, empty_token_ids + + ( + from_interegular_instance.states_to_token_maps, + from_interegular_instance.empty_token_ids, + ) = create_states_mapping_from_interegular_fsm(interegular_fsm) + from_interegular_instance.eos_token_id = tokenizer.eos_token_id + return from_interegular_instance + + def is_final_state(self, state: int) -> bool: + """Determine whether the current state of the guide is a final state.""" + return state in self.final_states + + def copy(self): + return self + + +class CFGGuide(Guide): + """Guide to generate text that is in the language of a context-free grammar.""" + + def __init__(self, cfg_string: str, tokenizer): + self.cfg_string = cfg_string + self.tokenizer = tokenizer + + self.parser = Lark( + cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) + self.terminal_regexps = dict() + for terminal in self.parser.terminals: + if terminal.pattern is not None: + self.terminal_regexps[terminal.name] = terminal.pattern.to_regexp() + self.terminal_regexps["$END"] = tokenizer.eos_token + + self.generation = "" + self.reset_state = False + self.allow_eos = False + self.regex_fsm: RegexGuide + + self.check_last = False + self.proposal_last: List[int] = [] + self.regex_fsm_last: RegexGuide + + self.start_state = 0 + self.final_state = -1 + + def get_next_instruction(self, state: int) -> Instruction: + """Generate an instruction for the next step. + + Upon initialization, the CFG incremental parser is used to determine the + first regex and construct the first FSM to generate the first terminal. + + This FSM is used for proposals until either: + + - The FSM is exhausted, and its only remaining option is the EOS token, + in which case we feed the generated terminal to the + CFG incremental parser and allow it to propose the next regex + corresponding to the next set of valid terminals. + - The current FSM can be exhausted, but the EOS token is not the only + remaining option. 
In this case we allow proposal of current terminal + extensions, store the current FSM and its state, then also use the CFG + parser to propose a new regex corresponding to terminating the current + terminal and starting the next one. The model can then sample from + either of these sets to determine whether to extend the current + terminal or terminate it and start the next one. + + The CFG incremental parser is allowed to propose the EOS token from any accepting state, + and once it is generated, the FSM will continue to always generate the EOS token. + + Parameters + ---------- + state + The current state of the FSM. + + Returns + ------- + A list that contains the tokens to mask. + + """ + if self.is_final_state(state): + return Write([self.tokenizer.eos_token_id]) + + proposal: List[int] = [] + if self.generation != "": + if self.check_last: + proposer = self.regex_fsm_last + else: + proposer = self.regex_fsm + + instruction = proposer.get_next_instruction(state) + + assert instruction.tokens is not None + + if isinstance(instruction, Write): + proposal += instruction.tokens + else: + proposal += instruction.tokens + + if self.tokenizer.eos_token_id not in proposal: + return Generate(proposal) + + self.check_last = False + proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] + if len(proposal) > 0: + self.check_last = True + self.proposal_last = proposal.copy() + self.regex_fsm_last = proposer + + interactive = self.parser.parse_interactive(self.generation) + interactive.exhaust_lexer() + + options = {self.terminal_regexps[x] for x in interactive.accepts()} + # add %ignore terminals + options |= {self.terminal_regexps[x] for x in self.parser.lexer_conf.ignore} + + if self.terminal_regexps["$END"] in options: + options.remove(self.terminal_regexps["$END"]) + if len(options) == 0: + return Write([self.tokenizer.eos_token_id]) + self.allow_eos = True + options.add("") + assert len(options) > 1 + + regex_string = r"(" + r"|".join([r"(" + x + r")" for x in options]) + r")" + self.regex_fsm = RegexGuide(regex_string, self.tokenizer) + self.reset_state = True + + instruction = self.regex_fsm.get_next_instruction(self.start_state) + + assert instruction.tokens is not None + + if isinstance(instruction, Write): + proposal += instruction.tokens + else: + proposal += instruction.tokens + + if self.allow_eos: + self.allow_eos = False + else: + proposal = [x for x in proposal if x != self.tokenizer.eos_token_id] + assert len(proposal) > 0 + + return Generate(proposal) + + def get_next_state(self, state: int, token_id: int) -> int: + """Update the state of the guide. + + Transitions the underlying regex FSM to its next state. + If at max tokens or EOS token, transition permanently to the final state. + Update stored partial generations for subsequent incremental parsing. + + Parameters + ---------- + state + The current state of the FSM. + token_id + The id of the token that was just generated. + + Returns + ------- + The new state of the FSM. + """ + + # We need to return the final state when in the final state because we + # then generate EOS tokens instead of stopping the generation. 
+ if token_id == self.tokenizer.eos_token_id or state == self.final_state: + return self.final_state + + self.generation += self.tokenizer.decode([token_id])[0] + + if self.check_last: + if token_id in self.proposal_last: + return self.regex_fsm_last.get_next_state(state, token_id) + self.check_last = False + + if self.reset_state: + self.reset_state = False + state = self.start_state + + return self.regex_fsm.get_next_state(state, token_id) + + def is_final_state(self, state: int) -> bool: + return state == self.final_state + + def copy(self) -> "CFGGuide": + """Create a copy of the FSM.""" + return CFGGuide(self.cfg_string, self.tokenizer) diff --git a/build/lib/outlines/fsm/json_schema.py b/build/lib/outlines/fsm/json_schema.py new file mode 100644 index 000000000..272fee5eb --- /dev/null +++ b/build/lib/outlines/fsm/json_schema.py @@ -0,0 +1,436 @@ +import inspect +import json +import re +import warnings +from typing import Callable, Optional + +from jsonschema.protocols import Validator +from pydantic import create_model +from referencing import Registry, Resource +from referencing._core import Resolver +from referencing.jsonschema import DRAFT202012 + +STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\S)' +STRING = f'"{STRING_INNER}*"' +INTEGER = r"(-)?(0|[1-9][0-9]*)" +NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" +BOOLEAN = r"(true|false)" +NULL = r"null" +WHITESPACE = r"[ ]?" + +type_to_regex = { + "string": STRING, + "integer": INTEGER, + "number": NUMBER, + "boolean": BOOLEAN, + "null": NULL, +} + +DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' +DATE = r'"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"' +TIME = r'"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"' +UUID = r'"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"' + +format_to_regex = { + "uuid": UUID, + "date-time": DATE_TIME, + "date": DATE, + "time": TIME, +} + + +def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): + """Turn a JSON schema into a regex that matches any JSON object that follows + this schema. + + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. These schemas can be generated from any Python + datastructure that has type annotation: namedtuples, dataclasses, Pydantic + models. And by ensuring that the generation respects the schema we ensure + that the output can be parsed into these objects. + This function parses the provided schema and builds a generation schedule which + mixes deterministic generation (fixed strings), and sampling with constraints. + + Parameters + ---------- + schema + A string that represents a JSON Schema. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + + Returns + ------- + A generation schedule. A list of strings that represent the JSON + schema's structure and regular expression that define the structure of + the fields. + + References + ---------- + .. [0] JSON Schema. 
https://json-schema.org/ + + """ + + schema = json.loads(schema) + Validator.check_schema(schema) + + # Build reference resolver + schema = Resource(contents=schema, specification=DRAFT202012) + uri = schema.id() if schema.id() is not None else "" + registry = Registry().with_resource(uri=uri, resource=schema) + resolver = registry.resolver() + + content = schema.contents + return to_regex(resolver, content, whitespace_pattern) + + +def _get_num_items_pattern(min_items, max_items, whitespace_pattern): + # Helper function for arrays and objects + min_items = int(min_items or 0) + if max_items is None: + return rf"{{{max(min_items - 1, 0)},}}" + else: + max_items = int(max_items) + if max_items < 1: + return None + return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" + + +def to_regex( + resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None +): + """Translate a JSON Schema instance into a regex that validates the schema. + + Note + ---- + Many features of JSON schema are missing: + - Handle `additionalProperties` keyword + - Handle types defined as a list + - Handle constraints on numbers + - Handle special patterns: `date`, `uri`, etc. + + This does not support recursive definitions. + + Parameters + ---------- + resolver + An object that resolves references to other instances within a schema + instance + The instance to translate + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + """ + + # set whitespace pattern + if whitespace_pattern is None: + whitespace_pattern = WHITESPACE + + if instance == {}: + # JSON Schema Spec: Empty object means unconstrained, any json type is legal + types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + {"type": "array"}, + {"type": "object"}, + ] + regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] + regexes = [rf"({r})" for r in regexes] + return rf"{'|'.join(regexes)}" + + elif "properties" in instance: + regex = "" + regex += r"\{" + properties = instance["properties"] + required_properties = instance.get("required", []) + is_required = [item in required_properties for item in properties] + # If at least one property is required, we include the one in the lastest position + # without any comma. + # For each property before it (optional or required), we add with a comma after the property. + # For each property after it (optional), we add with a comma before the property. + if any(is_required): + last_required_pos = max([i for i, value in enumerate(is_required) if value]) + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' + subregex += to_regex(resolver, value, whitespace_pattern) + if i < last_required_pos: + subregex = f"{subregex}{whitespace_pattern}," + elif i > last_required_pos: + subregex = f"{whitespace_pattern},{subregex}" + regex += subregex if is_required[i] else f"({subregex})?" + # If no property is required, we have to create a possible pattern for each property in which + # it's the last one necessarilly present. Then, we add the others as optional before and after + # following the same strategy as described above. + # The whole block is made optional to allow the case in which no property is returned. 
+ else: + property_subregexes = [] + for i, (name, value) in enumerate(properties.items()): + subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' + subregex += to_regex(resolver, value, whitespace_pattern) + property_subregexes.append(subregex) + possible_patterns = [] + for i in range(len(property_subregexes)): + pattern = "" + for subregex in property_subregexes[:i]: + pattern += f"({subregex}{whitespace_pattern},)?" + pattern += property_subregexes[i] + for subregex in property_subregexes[i + 1 :]: + pattern += f"({whitespace_pattern},{subregex})?" + possible_patterns.append(pattern) + regex += f"({'|'.join(possible_patterns)})?" + + regex += f"{whitespace_pattern}" + r"\}" + + return regex + + # To validate against allOf, the given data must be valid against all of the + # given subschemas. + elif "allOf" in instance: + subregexes = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] + ] + subregexes_str = [f"{subregex}" for subregex in subregexes] + return rf"({''.join(subregexes_str)})" + + # To validate against `anyOf`, the given data must be valid against + # any (one or more) of the given subschemas. + elif "anyOf" in instance: + subregexes = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] + ] + return rf"({'|'.join(subregexes)})" + + # To validate against oneOf, the given data must be valid against exactly + # one of the given subschemas. + elif "oneOf" in instance: + subregexes = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] + ] + + xor_patterns = [f"(?:{subregex})" for subregex in subregexes] + + return rf"({'|'.join(xor_patterns)})" + + # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx + elif "prefixItems" in instance: + element_patterns = [ + to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] + ] + comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" + tuple_inner = comma_split_pattern.join(element_patterns) + return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" + + # The enum keyword is used to restrict a value to a fixed set of values. It + # must be an array with at least one element, where each element is unique. + elif "enum" in instance: + choices = [] + for choice in instance["enum"]: + if type(choice) in [int, float, bool, None]: + choices.append(re.escape(str(choice))) + elif type(choice) == str: + choices.append(f'"{re.escape(choice)}"') + + return f"({'|'.join(choices)})" + + elif "const" in instance: + const = instance["const"] + if type(const) in [int, float, bool, None]: + const = re.escape(str(const)) + elif type(const) == str: + const = f'"{re.escape(const)}"' + return const + + elif "$ref" in instance: + path = f"{instance['$ref']}" + instance = resolver.lookup(path).contents + return to_regex(resolver, instance, whitespace_pattern) + + # The type keyword may either be a string or an array: + # - If it's a string, it is the name of one of the basic types. + # - If it is an array, it must be an array of strings, where each string is + # the name of one of the basic types, and each element is unique. In this + # case, the JSON snippet is valid if it matches any of the given types. 
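As a rough end-to-end illustration of the property handling above, checked here with Python's `re`; the schema is my own example and not part of the patch.

import json
import re

from outlines.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

# The default whitespace pattern is r"[ ]?"; it is passed explicitly here for clarity.
pattern = build_regex_from_schema(schema, whitespace_pattern=r"[ ]?")
assert re.fullmatch(pattern, '{"name": "Ada", "age": 36}')
assert re.fullmatch(pattern, '{"age": 36}') is None  # a required key is missing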
+ elif "type" in instance: + instance_type = instance["type"] + if instance_type == "string": + if "maxLength" in instance or "minLength" in instance: + max_items = instance.get("maxLength", "") + min_items = instance.get("minLength", "") + try: + if int(max_items) < int(min_items): + raise ValueError( + "maxLength must be greater than or equal to minLength" + ) + except ValueError: + pass + return f'"{STRING_INNER}{{{min_items},{max_items}}}"' + elif "pattern" in instance: + pattern = instance["pattern"] + if pattern[0] == "^" and pattern[-1] == "$": + return rf'(^"{pattern[1:-1]}"$)' + else: + return rf'("{pattern}")' + elif "format" in instance: + format = instance["format"] + if format == "date-time": + return format_to_regex["date-time"] + elif format == "uuid": + return format_to_regex["uuid"] + elif format == "date": + return format_to_regex["date"] + elif format == "time": + return format_to_regex["time"] + else: + raise NotImplementedError( + f"Format {format} is not supported by Outlines" + ) + else: + return type_to_regex["string"] + + elif instance_type == "number": + return type_to_regex["number"] + + elif instance_type == "integer": + return type_to_regex["integer"] + + elif instance_type == "array": + num_repeats = _get_num_items_pattern( + instance.get("minItems"), instance.get("maxItems"), whitespace_pattern + ) + if num_repeats is None: + return rf"\[{whitespace_pattern}\]" + + allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" + + if "items" in instance: + items_regex = to_regex(resolver, instance["items"], whitespace_pattern) + return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" + else: + # Here we need to make the choice to exclude generating list of objects + # if the specification of the object is not given, even though a JSON + # object that contains an object here would be valid under the specification. + legal_types = [ + {"type": "boolean"}, + {"type": "null"}, + {"type": "number"}, + {"type": "integer"}, + {"type": "string"}, + ] + depth = instance.get("depth", 2) + if depth > 0: + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + + regexes = [ + to_regex(resolver, t, whitespace_pattern) for t in legal_types + ] + return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" + + elif instance_type == "object": + # pattern for json object with values defined by instance["additionalProperties"] + # enforces value type constraints recursively, "minProperties", and "maxProperties" + # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" + num_repeats = _get_num_items_pattern( + instance.get("minProperties"), + instance.get("maxProperties"), + whitespace_pattern, + ) + if num_repeats is None: + return rf"\{{{whitespace_pattern}\}}" + + allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" + + additional_properties = instance.get("additionalProperties") + + if additional_properties is None or additional_properties is True: + # JSON Schema behavior: If the additionalProperties of an object is + # unset or True, it is unconstrained object. 
+ # We handle this by setting additionalProperties to anyOf: {all types} + + legal_types = [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + {"type": "null"}, + ] + + # We set the object depth to 2 to keep the expression finite, but the "depth" + # key is not a true component of the JSON Schema specification. + depth = instance.get("depth", 2) + if depth > 0: + legal_types.append({"type": "object", "depth": depth - 1}) + legal_types.append({"type": "array", "depth": depth - 1}) + additional_properties = {"anyOf": legal_types} + + value_pattern = to_regex( + resolver, additional_properties, whitespace_pattern + ) + key_value_pattern = ( + f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" + ) + key_value_successor_pattern = ( + f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" + ) + multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}" + + return ( + r"\{" + + whitespace_pattern + + multiple_key_value_pattern + + whitespace_pattern + + r"\}" + ) + + elif instance_type == "boolean": + return type_to_regex["boolean"] + + elif instance_type == "null": + return type_to_regex["null"] + + elif isinstance(instance_type, list): + # Here we need to make the choice to exclude generating an object + # if the specification of the object is not give, even though a JSON + # object that contains an object here would be valid under the specification. + regexes = [ + to_regex(resolver, {"type": t}, whitespace_pattern) + for t in instance_type + if t != "object" + ] + return rf"({'|'.join(regexes)})" + + raise NotImplementedError( + f"""Could not translate the instance {instance} to a + regular expression. Make sure it is valid to the JSON Schema specification. If + it is, please open an issue on the Outlines repository""" + ) + + +def get_schema_from_signature(fn: Callable) -> str: + """Turn a function signature into a JSON schema. + + Every JSON object valid to the output JSON Schema can be passed + to `fn` using the ** unpacking syntax. + + """ + signature = inspect.signature(fn) + arguments = {} + for name, arg in signature.parameters.items(): + if arg.annotation == inspect._empty: + raise ValueError("Each argument must have a type annotation") + else: + arguments[name] = (arg.annotation, ...) + + try: + fn_name = fn.__name__ + except Exception as e: + fn_name = "Arguments" + warnings.warn( + f"The function name could not be determined. Using default name 'Arguments' instead. 
For debugging, here is exact error:\n{e}", + category=UserWarning, + ) + model = create_model(fn_name, **arguments) + + return model.model_json_schema() diff --git a/build/lib/outlines/fsm/parsing.py b/build/lib/outlines/fsm/parsing.py new file mode 100644 index 000000000..e4fa7b764 --- /dev/null +++ b/build/lib/outlines/fsm/parsing.py @@ -0,0 +1,870 @@ +from copy import copy, deepcopy +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict, FrozenSet, Iterator, Optional, Set, Tuple, Union + +import interegular +from interegular.fsm import FSM +from interegular.patterns import Unsupported +from lark import Lark, Token +from lark.common import LexerConf, ParserConf +from lark.exceptions import LexError, UnexpectedInput +from lark.indenter import Indenter +from lark.lexer import ( + BasicLexer, + ContextualLexer, + LexerState, + LexerThread, + Scanner, + UnexpectedCharacters, + UnexpectedToken, + _create_unless, +) +from lark.parser_frontends import ( + ParsingFrontend, + PostLexConnector, + _validate_frontend_args, +) +from lark.parsers.lalr_analysis import ( + Action, + IntParseTable, + LALR_Analyzer, + ParseTable, + Shift, +) +from lark.parsers.lalr_interactive_parser import InteractiveParser +from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser + +from outlines.fsm.regex import ( + fsm_union, + get_sub_fsms_from_seq, + get_token_transition_keys, + make_deterministic_fsm, + walk_fsm, +) + +PartialParseState = Tuple[str, int] +ParseStateType = Union[int, FrozenSet] + + +@dataclass +class PartialTerminalInfo: + priority: int + terminal_name: str + can_transition: bool + is_final: bool + + +@dataclass +class PartialTokensInfo: + fsm_state_seq: Tuple[int, ...] + is_not_finished: bool + terminals_and_info: Tuple[PartialTerminalInfo, ...] + final_terminals_and_info: Tuple[PartialTerminalInfo, ...] + + +class PartialParserConf(ParserConf): + __serialize_fields__ = ( + "rules", + "start", + "parser_type", + "deterministic", + "use_value_stack", + ) + + def __init__(self, rules, callbacks, start, deterministic, use_value_stack): + super().__init__(rules, callbacks, start) + self.deterministic = deterministic + self.use_value_stack = use_value_stack + + +class PartialLark(Lark): + __serialize_fields__ = ( + "parser", + "rules", + "options", + "deterministic", + "use_value_stack", + ) + + def __init__(self, grammar, **options): + # TODO: Could've extended `LarkOptions`, but all these extensions are + # already way too much (and brittle). This library really needs a + # complete refactoring. 
+ self.deterministic = options.pop("deterministic", False) + self.use_value_stack = options.pop("use_value_stack", False) + options["regex"] = True + super().__init__(grammar, **options) + assert self.options.parser == "lalr" + + def _build_lexer(self, dont_ignore: bool = False) -> "PartialBasicLexer": + lexer_conf = self.lexer_conf + if dont_ignore: + from copy import copy + + lexer_conf = copy(lexer_conf) + lexer_conf.ignore = () + + return PartialBasicLexer(lexer_conf) + + def _build_parser(self) -> "PartialParsingFrontend": + self._prepare_callbacks() + _validate_frontend_args(self.options.parser, self.options.lexer) + parser_conf = PartialParserConf( + self.rules, + self._callbacks, + self.options.start, + self.deterministic, + self.use_value_stack, + ) + + # This is `_construct_parsing_frontend` expanded/inlined + parser_type = self.options.parser + lexer_type = self.options.lexer + lexer_conf = self.lexer_conf + + assert isinstance(lexer_conf, LexerConf) + assert isinstance(parser_conf, ParserConf) + parser_conf.parser_type = parser_type + self.lexer_conf.lexer_type = lexer_type + return PartialParsingFrontend(lexer_conf, parser_conf, self.options) + + def __repr__(self): + return "{}(open({!r}), parser={!r}, lexer={!r}, ...)".format( + type(self).__name__, + self.source_path, + self.options.parser, + self.options.lexer, + ) + + def parse_from_state(self, parse_state: "PartialParseState", is_end=False): + return self.parser.parser.parser.parse_from_state(parse_state, is_end=is_end) + + +class PartialLexerThread(LexerThread): + def __copy__(self): + return type(self)(copy(self.lexer), copy(self.state)) + + def __repr__(self): + return f"{type(self).__name__}(lexer={self.lexer!r}, state={self.state!r})" + + +class PartialPostLexConnector(PostLexConnector): + def __copy__(self): + return type(self)(self.lexer, copy(self.postlexer)) + + def __repr__(self): + return ( + f"{type(self).__name__}(lexer={self.lexer!r}, postlexer={self.postlexer!r})" + ) + + +class PartialParsingFrontend(ParsingFrontend): + def __init__(self, lexer_conf, parser_conf, options, parser=None): + assert parser_conf.parser_type == "lalr" + + options._plugins["LALR_Parser"] = PartialLALRParser + options._plugins["BasicLexer"] = PartialBasicLexer + options._plugins["ContextualLexer"] = PartialContextualLexer + options._plugins["LexerThread"] = PartialLexerThread + + super().__init__(lexer_conf, parser_conf, options, parser=parser) + + if lexer_conf.postlex: + self.lexer = PartialPostLexConnector(self.lexer.lexer, lexer_conf.postlex) + + self._termset_fsm_info = None + self._symbols_to_states: Optional[ + Dict[str, Set[Tuple[ParseStateType, Action]]] + ] = None + self._reverse_shifts: Optional[ + Dict[ParseStateType, Dict[str, Set[ParseStateType]]] + ] = None + # self._state_transition_map: Optional[ + # Dict[Tuple[ParseStateType, str], Set[ParseStateType]] + # ] = None + + def _compute_maps( + self, + ): + """Compute state transition and symbols-to-states maps.""" + self._reverse_shifts = {} + self._symbols_to_states = {} + + parse_table = self.parser.parser.parse_table + + for from_state, symbols_to_ops in parse_table.states.items(): + for symbol, op in symbols_to_ops.items(): + if op[0] == Shift: + symbols_to_from_states = self._reverse_shifts.setdefault(op[1], {}) + symbols_to_from_states.setdefault(symbol, set()).add(from_state) + self._symbols_to_states.setdefault(symbol, set()).add((from_state, op)) + + # # TODO: This approach is very wasteful. 
+ # context_lexer = get_contextual_lexer(self) + # self._state_transition_map = {} + # + # for from_state, transitions in parse_table.states.items(): + # for symbol, action in transitions.items(): + # # TODO: Filter non-terminals + # if symbol not in context_lexer.root_lexer.terminals_by_name: + # continue + # + # if action[0] is Shift: + # self._state_transition_map.setdefault( + # (from_state, symbol), set() + # ).add(action[1]) + # continue + # + # antecedent_state_seqs = parse_to_terminal(self, [(from_state,)], symbol) + # + # for antecedent_state_seq in antecedent_state_seqs: + # antecedent_state = antecedent_state_seq[-1] + # self._state_transition_map.setdefault( + # (from_state, symbol), set() + # ).add(antecedent_state) + + def _compute_termset_fsm_info(self): + """Collect and return information about terminal symbol sets and their FSMs. + + Terminal symbol sets (or "termsets") are ordered sequences of terminal + symbols that are used by each parser state. Associated with each is a + collection of FSMs for each terminal and a single parse state FSM that is + the union of each terminal's FSM. + + This constructs a list of tuples containing the termset, the set of + parse states that use the termsets, parse state FSMs, and information + mapping the components of the parse state FSMs to their terminal symbol + FSMs. + + """ + context_lexer = get_contextual_lexer(self) + termsets_to_fsms = {} + termsets_to_parse_states: Dict[Tuple[str, ...], Set[ParseStateType]] = {} + for parse_state, lexer in context_lexer.lexers.items(): + scanner = lexer.scanner + key = tuple(term.name for term in scanner.terminals) + termsets_to_fsms[key] = (scanner.fsm, scanner.fsms_to_trans_finals) + termsets_to_parse_states.setdefault(key, set()).add(parse_state) + + self._termset_fsm_info = [ + ( + termset, + frozenset(termsets_to_parse_states[termset]), + fsm, + fsms_to_trans_finals, + ) + for termset, (fsm, fsms_to_trans_finals) in termsets_to_fsms.items() + ] + + @property + def termset_fsm_info(self): + if self._termset_fsm_info is None: + self._compute_termset_fsm_info() + return self._termset_fsm_info + + @property + def symbols_to_states(self): + if self._symbols_to_states is None: + self._compute_maps() + return self._symbols_to_states + + @property + def reverse_shifts(self): + if self._reverse_shifts is None: + self._compute_maps() + return self._reverse_shifts + + # @property + # def state_transition_map(self): + # if self._state_transition_map is None: + # self._compute_maps() + # return self._state_transition_map + + +class PartialLALRParser(LALR_Parser): + def __init__(self, parser_conf, debug=False, strict=False): + analysis = LALR_Analyzer( + parser_conf, debug=debug if not parser_conf.deterministic else True + ) + analysis.compute_lalr() + callbacks = parser_conf.callbacks + + self.parser_conf = parser_conf + self._parse_table = analysis.parse_table + + if parser_conf.deterministic: + old_to_new = {} + + def to_tuple(v): + new = old_to_new.get(v) + if new is None: + new = tuple(sorted(v, key=lambda y: str(y))) + old_to_new[v] = new + return new + + enum = sorted( + self._parse_table.states.keys(), + key=lambda x: str(sorted(x, key=lambda y: str(y))), + ) + + new_states = {} + for s in enum: + transitions = { + term: op if op[0] is not Shift else (op[0], to_tuple(op[1])) + for term, op in self._parse_table.states[s].items() + } + new_states[to_tuple(s)] = transitions + + self._parse_table = type(self._parse_table)( + new_states, + {k: to_tuple(v) for k, v in 
self._parse_table.start_states.items()}, + {k: to_tuple(v) for k, v in self._parse_table.end_states.items()}, + ) + + if not debug: + self._parse_table = IntParseTable.from_ParseTable(self._parse_table) + self.states_to_rulesets = dict( + zip(self._parse_table.states.keys(), new_states.keys()) + ) + + self.parser = PartialParser( + self._parse_table, + callbacks, + debug, + use_value_stack=parser_conf.use_value_stack, + ) + + @classmethod + def deserialize(cls, data, memo, callbacks, debug=False): + inst = cls.__new__(cls) + inst._parse_table = ParseTable.deserialize(data, memo) + inst.parser = PartialParser(inst._parse_table, callbacks, debug) + return inst + + +class PartialParserState(ParserState): + __slots__ = "use_value_stack" + + def __init__( + self, + parse_conf, + lexer, + state_stack=None, + value_stack=None, + use_value_stack=False, + ): + super().__init__( + parse_conf, lexer, state_stack=state_stack, value_stack=value_stack + ) + self.use_value_stack = use_value_stack + + def feed_token(self, token, is_end=False): + if token.type == "partial": + # If none of the potential terminals can transition, we need to know now + current_state = self.state_stack[-1] + current_lexer = get_contextual_lexer(self.lexer).lexers[current_state] + + # We have to feed the token and determine whether or not at least + # one terminal is consistent with the stack; otherwise, we'll miss + # invalid REDUCE cases. + # TODO: We should track separate parses conditional on possible + # token/symbol types, then we can coherently reuse the following + # results instead of recomputing it later. + can_transition = False + for terminal_info in token.value.terminals_and_info: + if terminal_info.terminal_name not in current_lexer.ignore_types: + test_token = Token.new_borrow_pos( + terminal_info.terminal_name, "", token + ) + + stack = copy(self.state_stack) + try: + self.feed_token_no_stack(test_token, is_end=is_end) + can_transition = True + break + except UnexpectedToken: + continue + finally: + self.state_stack = stack + else: + can_transition = True + + if not can_transition: + expected = { + s + for s in self.parse_conf.states[current_state].keys() + if s.isupper() + } + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + elif self.use_value_stack: + super().feed_token(token, is_end=is_end) + else: + self.feed_token_no_stack(token, is_end=is_end) + + def feed_token_no_stack(self, token, is_end=False): + """ + This is a copy of `ParserState.feed_token` with all the value stack + steps removed. Since we're not exactly parsing in order to obtain a + CST or anything similar, we can avoid the growing expense of tracking + the parse tree. 
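+        Shift actions push the new state and return immediately; reductions
+        pop len(rule.expansion) states and push the goto state, repeating
+        until a shift (or, at the end of input, the end state) is reached.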
+ """ + state_stack = self.state_stack + states = self.parse_conf.states + end_state = self.parse_conf.end_state + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken( + token, expected, state=self, interactive_parser=None + ) + + assert arg != end_state + + if action is Shift: + # shift once and return + assert not is_end + state_stack.append(arg) + return + else: + # reduce+shift as many times as necessary + rule = arg + size = len(rule.expansion) + if size: + del state_stack[-size:] + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action is Shift + state_stack.append(new_state) + + if is_end and state_stack[-1] == end_state: + return + + def __copy__(self): + return type(self)( + self.parse_conf, + copy(self.lexer), + copy(self.state_stack), + deepcopy(self.value_stack), + use_value_stack=self.use_value_stack, + ) + + def __repr__(self): + return f"{type(self).__name__}(lexer={self.lexer!r}, state_stack={self.state_stack!r})" + + +class PartialParser(_Parser): + def __init__(self, parse_table, callbacks, debug=False, use_value_stack=False): + super().__init__(parse_table, callbacks, debug=debug) + self.use_value_stack = use_value_stack + + def parse( + self, lexer, start, value_stack=None, state_stack=None, start_interactive=False + ): + parse_conf = ParseConf(self.parse_table, self.callbacks, start) + parser_state = PartialParserState( + parse_conf, copy(lexer), state_stack, value_stack, self.use_value_stack + ) + if start_interactive: + return InteractiveParser(self, parser_state, parser_state.lexer) + return self.parse_from_state(parser_state) + + def parse_from_state(self, state, last_token=None, is_end=False): + try: + token = last_token + for token in state.lexer.lex(state): + state.feed_token(token) + + if is_end and (not token or token.type != "partial"): + end_token = ( + Token.new_borrow_pos("$END", "", token) + if token + else Token("$END", "", 0, 1, 1) + ) + state.feed_token(end_token, True) + + return state + except UnexpectedInput as e: + try: + e.interactive_parser = InteractiveParser(self, state, state.lexer) + except NameError: + pass + raise e + except Exception: + if self.debug: + print("") + print("STATE STACK DUMP") + print("----------------") + for i, s in enumerate(state.state_stack): + print("%d)" % i, s) + print("") + + raise + + +class PartialScanner(Scanner): + @classmethod + @lru_cache + def construct_terminal_fsm(cls, terminal): + # TODO: This should really be done at the lexer/parser level so that + # the lifetime of these objects is tied to the parser itself. + regex_str = terminal.pattern.to_regexp() + pattern = interegular.parse_pattern(regex_str) + fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) + return fsm, pattern.prefix_postfix + + def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False): + self.terminals = terminals + self.g_regex_flags = g_regex_flags + self.use_bytes = use_bytes + self.match_whole = match_whole + self.allowed_types = {t.name for t in self.terminals} + self._mres = None + + fsms = [] + for t in self.terminals: + fsm, prefix_postfix = self.construct_terminal_fsm(t) + + # TODO FIXME: We don't support this right now. 
+ assert prefix_postfix == (0, 0) + + fsms.append(fsm) + + self.fsm, self.fsms_to_trans_finals = fsm_union(fsms) + + def get_terminals_info( + self, fsm_state_seq + ) -> Tuple[Tuple[PartialTerminalInfo, ...], Tuple[PartialTerminalInfo, ...]]: + """Get the possible terminal symbols for an FSM state sequence.""" + terminals_and_info: Tuple[PartialTerminalInfo, ...] = () + final_terminals_and_info: Tuple[PartialTerminalInfo, ...] = () + for i, (fsm_id, fsm_reads_more, in_final) in enumerate( + get_sub_fsms_from_seq(fsm_state_seq, self.fsms_to_trans_finals) + ): + terminal_name = self.terminals[fsm_id].name + info = PartialTerminalInfo(i, terminal_name, fsm_reads_more, in_final) + terminals_and_info += (info,) + if in_final: + final_terminals_and_info += (info,) + + return terminals_and_info, final_terminals_and_info + + def match(self, text, pos, last_fsm_state_seq: Optional[Tuple[int, ...]] = None): + """Determine an FSM match over `text` starting at `pos` and continuing `last_fsm_state_seq`.""" + + start_pos = pos + + if last_fsm_state_seq: + assert len(last_fsm_state_seq) > 1 + start_pos += len(last_fsm_state_seq) - 1 + start_state = last_fsm_state_seq[-1] + else: + start_state = self.fsm.initial + + text_part = text[start_pos:] + + text_transitions = get_token_transition_keys( + self.fsm.fsm_info.alphabet_symbol_mapping, + self.fsm.fsm_info.alphabet_anything_value, + text_part, + ) + + state_seq = walk_fsm( + self.fsm, + text_transitions, + start_state, + full_match=self.match_whole, + ) + + if not state_seq: + return None + + if last_fsm_state_seq: + res = last_fsm_state_seq + tuple(state_seq) + else: + res = (start_state,) + tuple(state_seq) + + return res + + +class PartialContextualLexer(ContextualLexer): + def __init__(self, conf: "LexerConf", states, always_accept=()): + terminals = list(conf.terminals) + terminals_by_name = conf.terminals_by_name + + trad_conf = copy(conf) + trad_conf.terminals = terminals + + lexer_by_symbols: Dict = {} + self.lexers = {} + for state, accepts in states.items(): + key = frozenset(accepts) + try: + lexer = lexer_by_symbols[key] + except KeyError: + accepts = set(accepts) | set(conf.ignore) | set(always_accept) + lexer_conf = copy(trad_conf) + lexer_conf.terminals = [ + terminals_by_name[n] for n in accepts if n in terminals_by_name + ] + lexer = PartialBasicLexer(lexer_conf) + lexer_by_symbols[key] = lexer + + self.lexers[state] = lexer + + assert trad_conf.terminals is terminals + self.root_lexer = PartialBasicLexer(trad_conf) + + def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + try: + while True: + lexer = self.lexers[parser_state.position] + yield lexer.next_token(lexer_state, parser_state) + except EOFError: + pass + + +class PartialBasicLexer(BasicLexer): + def __init__(self, conf: "LexerConf"): + super().__init__(conf) + # Eagerly construct the scanner + self._build_scanner() + + def _build_scanner(self): + # This seems incredibly convoluted: `lark` creates callback-triggered + # nested scanners for regex-defined terminals that overlap with + # string-defined terminals when both types of terminals have the same + # priority. Unless I'm missing something important, why not simply + # reorder the terminals so that the string-defined ones come before the + # regex-defined ones? + terminals, self.callback = _create_unless( + self.terminals, self.g_regex_flags, self.re, self.use_bytes + ) + + # We can't let people arbitrarily mess with the scanning process. 
+ assert not self.user_callbacks + # for type_, f in self.user_callbacks.items(): + # if type_ in self.callback: + # # Already a callback there, probably UnlessCallback + # self.callback[type_] = CallChain( + # self.callback[type_], f, lambda t: t.type == type_ + # ) + # else: + # self.callback[type_] = f + + # We used the "callback" results to reorder the terminals (see the + # comments above). + for terminal_name, callback in self.callback.items(): + terminal = self.terminals_by_name[terminal_name] + for sub_terminal in callback.scanner.terminals: + self.terminals.remove(sub_terminal) + idx = self.terminals.index(terminal) + self.terminals.insert(idx, sub_terminal) + + self._scanner = PartialScanner( + self.terminals, self.g_regex_flags, self.re, self.use_bytes + ) + + def match(self, text, pos, last_fsm_state_seq=None): + return self.scanner.match(text, pos, last_fsm_state_seq) + + def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token: + last_token = lex_state.last_token + + last_fsm_state_seq = None + if last_token and last_token.type == "partial": + # Continue from last partial lexer state + last_fsm_state_seq = last_token.value.fsm_state_seq + + line_ctr = lex_state.line_ctr + end_pos = line_ctr.char_pos + ( + len(last_fsm_state_seq) - 1 if last_fsm_state_seq else 0 + ) + while end_pos < len(lex_state.text): + res = self.match(lex_state.text, line_ctr.char_pos, last_fsm_state_seq) + + if not res: + if ( + not last_fsm_state_seq + or last_fsm_state_seq[-1] not in self.scanner.fsm.finals + ): + allowed = self.scanner.allowed_types - self.ignore_types + if not allowed: + allowed = {""} + raise UnexpectedCharacters( + lex_state.text, + line_ctr.char_pos, + line_ctr.line, + line_ctr.column, + allowed=allowed, + token_history=lex_state.last_token and [lex_state.last_token], + state=parser_state, + terminals_by_name=self.terminals_by_name, + ) + + # The partial match might be complete now + fsm_state_seq = last_token.value.fsm_state_seq + terminals_and_info = last_token.value.terminals_and_info + final_terminals_and_info = last_token.value.final_terminals_and_info + else: + fsm_state_seq = res + ( + terminals_and_info, + final_terminals_and_info, + ) = self.scanner.get_terminals_info(fsm_state_seq) + + priority_terminal_info = ( + final_terminals_and_info[0] + if final_terminals_and_info + else terminals_and_info[0] + ) + + is_not_finished = ( + not priority_terminal_info.is_final + or priority_terminal_info.can_transition + or len(terminals_and_info) > 1 + ) + + start_pos = line_ctr.char_pos + end_pos = start_pos + len(fsm_state_seq) - 1 + + if end_pos >= len(lex_state.text) and is_not_finished: + type_name = "partial" + token_value = PartialTokensInfo( + fsm_state_seq, + is_not_finished, + terminals_and_info, + final_terminals_and_info, + ) + # Don't update the line counter states until we've finished + value = "" + else: + type_name = priority_terminal_info.terminal_name + # The token value should contain all partial scan parts in this + # case + value = token_value = lex_state.text[start_pos:end_pos] + + assert isinstance(self.callback, Dict) + + if type_name not in self.ignore_types: + t = Token( + type_name, + token_value, + line_ctr.char_pos, + line_ctr.line, + line_ctr.column, + ) + + line_ctr.feed(value, type_name in self.newline_types) + + t.end_line = line_ctr.line + t.end_column = line_ctr.column + t.end_pos = line_ctr.char_pos + if t.type in self.callback: + t = self.callback[t.type](t) + if not isinstance(t, Token): + raise LexError( + "Callbacks must 
return a token (returned %r)" % t + ) + lex_state.last_token = t + return t + + if type_name in self.callback: + t2 = Token( + type_name, value, line_ctr.char_pos, line_ctr.line, line_ctr.column + ) + self.callback[type_name](t2) + + line_ctr.feed(value, type_name in self.newline_types) + + last_fsm_state_seq = None + + raise EOFError(self) + + +class PartialIndenter(Indenter): + """An `Indenter` that doesn't reset its state every time `process` is called.""" + + def process(self, stream): + return self._process(stream) + + def _process(self, stream): + for token in stream: + # These were previously *after* the `yield`, but that makes the + # state tracking unnecessarily convoluted. + if token.type in self.OPEN_PAREN_types: + self.paren_level += 1 + elif token.type in self.CLOSE_PAREN_types: + self.paren_level -= 1 + if self.paren_level < 0: + raise UnexpectedToken(token, []) + + if token.type == self.NL_type: + yield from self.handle_NL(token) + else: + yield token + + # TODO: What do we want to do here? + # while len(self.indent_level) > 1: + # self.indent_level.pop() + # yield Token(self.DEDENT_type, "") + + def accepts_token_type(self, token_type): + if token_type in self.CLOSE_PAREN_types and self.paren_level - 1 < 0: + return False + + # TODO: + # if token_type == self.NL_type and self.paren_level == 0: + # ... + # return False + + return True + + def __copy__(self): + res = type(self)() + res.paren_level = self.paren_level + res.indent_level = copy(self.indent_level) + return res + + def __repr__(self): + return f"{type(self).__name__}(paren_level={self.paren_level!r}, indent_level={self.indent_level!r})" + + +class PartialPythonIndenter(PartialIndenter): + NL_type = "_NEWLINE" + OPEN_PAREN_types = ["LPAR", "LSQB", "LBRACE"] + CLOSE_PAREN_types = ["RPAR", "RSQB", "RBRACE"] + INDENT_type = "_INDENT" + DEDENT_type = "_DEDENT" + tab_len = 8 + + +def get_contextual_lexer(x: Union[PartialLexerThread, PartialParsingFrontend]): + if isinstance(x.lexer, ContextualLexer): + return x.lexer + else: + return x.lexer.lexer + + +def terminals_to_fsms(lp: PartialLark) -> Dict[str, FSM]: + """Construct a ``dict`` mapping terminal symbol names to their finite state machines.""" + + symbol_names_and_fsms = {} + for terminal in lp.terminals: + pattern = interegular.parse_pattern(terminal.pattern.to_regexp()) + # TODO: Use `pyparser.terminals[0].pattern.flags`? 
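+        # Note: `interegular` cannot compile every regex feature to an FSM
+        # (some constructs, e.g. look-arounds, raise `Unsupported`); such
+        # terminals are mapped to None below rather than failing the grammar.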
+ try: + fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) + except Unsupported: + fsm = None + + symbol_names_and_fsms[terminal.name] = fsm + + return symbol_names_and_fsms diff --git a/build/lib/outlines/fsm/regex.py b/build/lib/outlines/fsm/regex.py new file mode 100644 index 000000000..644a59186 --- /dev/null +++ b/build/lib/outlines/fsm/regex.py @@ -0,0 +1,922 @@ +import re +from collections import namedtuple +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Dict, + FrozenSet, + Generator, + List, + Sequence, + Set, + Tuple, + Union, + cast, +) + +import numba +import numpy as np +from interegular.fsm import ( + FSM, + Alphabet, + OblivionError, + State, + TransitionKey, + _AnythingElseCls, + anything_else, +) +from numba.typed.typedobjectutils import _nonoptional +from tqdm import tqdm + +from outlines.fsm.vocab_trie import VocabTrie + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +class BetterAlphabet(Alphabet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert anything_else in self._symbol_mapping + self.anything_value = self._symbol_mapping[anything_else] + + def __getitem__(self, item): + return self._symbol_mapping.get(item, self.anything_value) + + def copy(self): + return BetterAlphabet(self._symbol_mapping.copy()) + + +class BetterFSM(FSM): + flat_transition_map: Dict[Tuple[int, int], int] + trans_key_to_states: Dict[int, List[int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(self.alphabet, BetterAlphabet): + self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) + + flat_transition_map = {} + trans_key_to_states = {} + for from_state, trans_map in self.map.items(): + for trans_key, to_state in trans_map.items(): + flat_transition_map[(from_state, trans_key)] = to_state + trans_key_to_states.setdefault(trans_key, set()).add(from_state) + + self.__dict__["trans_key_to_states"] = trans_key_to_states + self.__dict__["flat_transition_map"] = flat_transition_map + self.__dict__["_fsm_info"] = None + + def copy(self): + return BetterFSM( + alphabet=self.alphabet.copy(), + states=self.states.copy(), + initial=self.initial, + finals=self.finals.copy(), + map=self.map.copy(), + __no_validation__=True, + ) + + @property + def fsm_info(self): + if self._fsm_info is None: + flat_transition_map_items = np.fromiter( + ((a[0], a[1], b) for a, b in self.flat_transition_map.items()), + dtype=np.dtype("int64, int64, int64"), + ) + trans_key_to_states_items = np.fromiter( + ((k, z) for k, v in self.trans_key_to_states.items() for z in v), + dtype=np.dtype("int64, int64"), + ) + alphabet_symbol_mapping_items = [ + (k, v) + for k, v in self.alphabet._symbol_mapping.items() + if k != anything_else + ] + nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) + self.__dict__["_fsm_info"] = create_fsm_info( + self.initial, + nb_finals, + flat_transition_map_items, + trans_key_to_states_items, + self.alphabet.anything_value, + alphabet_symbol_mapping_items, + ) + + return self._fsm_info + + +nb_int_list_type = numba.types.ListType(numba.int64) +nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) +nb_unicode_type = numba.types.unicode_type + + +@numba.njit(cache=True) +def create_fsm_info( + py_initial, + py_finals, + flat_transition_map_items, + trans_key_to_states_items, + py_anything_value, + alphabet_symbol_mapping_items, +): + trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type) + for 
trans_key_and_state in trans_key_to_states_items: + trans_key_to_states.setdefault( + trans_key_and_state[0], numba.typed.List.empty_list(numba.int64) + ).append(trans_key_and_state[1]) + + flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64) + for trans_key_and_state in flat_transition_map_items: + flat_transition_map[ + (trans_key_and_state[0], trans_key_and_state[1]) + ] = trans_key_and_state[2] + + # use 2-char strings so that we can represent incomplete utf-8 sequences + # as 2-hex-digit pairs + alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64) + for symbol_and_trans_key in alphabet_symbol_mapping_items: + alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] + + initial = numba.int64(py_initial) + + finals = set() + for final in py_finals: + finals.add(final) + + anything_value = numba.int64(py_anything_value) + + return FSMInfo( + initial, + finals, + flat_transition_map, + trans_key_to_states, + anything_value, + alphabet_symbol_map, + ) + + +FSMInfo = namedtuple( + "FSMInfo", + [ + "initial", + "finals", + "transitions", + "trans_key_to_states", + "alphabet_anything_value", + "alphabet_symbol_mapping", + ], +) + + +TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"] + + +def add_to_transition_trie( + trie: TransitionTrie, + key_seq: Sequence[TransitionKey], + value: Union[State, None], +): + for key in key_seq[:-1]: + trie = cast(TransitionTrie, trie.setdefault(key, {})) + assert isinstance(trie, dict), "key sequence of incompatible length" + trie[key_seq[-1]] = value + + +# merge default_trie into the trie, only updating entries not present in the trie +def transition_trie_setdefault( + trie: TransitionTrie, + default_trie: TransitionTrie, +): + for key, default_value in default_trie.items(): + dest_value = trie.get(key) + if isinstance(dest_value, dict) and isinstance(default_value, dict): + transition_trie_setdefault(dest_value, default_value) + elif key not in trie: + trie[key] = default_value + + +def byte_symbol(byte: int) -> str: + return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte) + + +def make_byte_level_fsm(fsm: FSM, keep_utf8=False) -> FSM: + """Convert an FSM to a byte-level FSM, expanding multi-byte characters as + sequences of single-byte transitions. If keep_utf8 is set, the original + utf-8 characters are kept in the alphabet. + NOTE: we're representing bytes as strings to keep it type-compatible. 
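+    Bytes >= 0x80 are encoded as a NUL character followed by two hex digits
+    (see `byte_symbol`), while ASCII bytes keep their single-character form.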
+ """ + + anything_else_key = fsm.alphabet[anything_else] + symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey] = {} + map: Dict[State, Dict[TransitionKey, State]] = {} + states: List[State] = list(fsm.states) + + # identify all multi-byte characters in the alphabet and build a mapping + # from the original transition keys to sequences of new keys for each byte + key_to_key_seqs: Dict[TransitionKey, Set[Tuple[TransitionKey, ...]]] = {} + all_key_seqs: Set[Tuple[TransitionKey, ...]] = set() + all_bytes: Set[int] = set() + max_key = max(fsm.alphabet.values()) + for symbol, transition_key in fsm.alphabet.items(): + assert symbol == anything_else or len(symbol) == 1 + if symbol == anything_else or ord(symbol) < 0x80: + symbol_mapping[symbol] = transition_key + else: + if keep_utf8: + symbol_mapping[symbol] = transition_key + key_list: List[TransitionKey] = [] + for byte in symbol.encode("utf-8"): + symbol = byte_symbol(byte) + if symbol not in symbol_mapping: + symbol_mapping[symbol] = max_key = TransitionKey(max_key + 1) + all_bytes.add(byte) + key_list.append(symbol_mapping[symbol]) + key_seq = tuple(key_list) + key_to_key_seqs.setdefault(transition_key, set()).add(key_seq) + all_key_seqs.add(key_seq) + + # add all remaining multi-byte utf-8 bytes to the alphabet + # (this is required to represent `anything_else`) + utf8_ranges = { + 1: (0x80, 0xC0), # continuation bytes + 2: (0xC0, 0xE0), # 2-byte sequences + 3: (0xE0, 0xF0), # 3-byte sequences + 4: (0xF0, 0xF8), # 4-byte sequences + } + utf8_all_keys: Dict[int, Set[TransitionKey]] = { + n: set() for n in utf8_ranges.keys() + } + for n, (start, end) in utf8_ranges.items(): + range_key = max_key = TransitionKey(max_key + 1) + for byte in range(start, end): + byte_key = symbol_mapping.setdefault(byte_symbol(byte), range_key) + utf8_all_keys[n].add(byte_key) + + # cache of intermediate transition states by transitions from that state + state_cache: Dict[FrozenSet[Tuple[TransitionKey, State]], State] = {} + + # helper function to create multi-step transitions between states + max_state = max(fsm.states) + + def create_seq_transitions( + seq_transitions_trie: TransitionTrie, + ) -> Dict[TransitionKey, State]: + nonlocal max_state + result: Dict[TransitionKey, State] = {} + + for next_key, next_trie in seq_transitions_trie.items(): + if isinstance(next_trie, dict): + next_transitions = create_seq_transitions(next_trie) + if not next_transitions: + continue + cache_key = frozenset(next_transitions.items()) + next_state = state_cache.get(cache_key) + if next_state is None: + next_state = max_state = State(max_state + 1) + map[next_state] = next_transitions + state_cache[cache_key] = next_state + states.append(next_state) + result[next_key] = next_state + elif next_trie is not None: + result[next_key] = next_trie + + return result + + # create new states and transitions + for state, transitions in fsm.map.items(): + seq_transitions_trie: TransitionTrie = {} + state_map: Dict[TransitionKey, State] = {} + + for transition_key, to_state in transitions.items(): + if transition_key in key_to_key_seqs: + if keep_utf8: + state_map[transition_key] = to_state + for key_seq in key_to_key_seqs[transition_key]: + add_to_transition_trie(seq_transitions_trie, key_seq, to_state) + else: # keep single-byte transitions as is + state_map[transition_key] = to_state + + # handle multi-byte anything_else sequences + if anything_else_key in transitions: + for key_seq in all_key_seqs: + add_to_transition_trie(seq_transitions_trie, key_seq, None) + + 
anything_else_trie: TransitionTrie = {} + cont_trie: Union[TransitionTrie, State] = transitions[anything_else_key] + for n in range(2, 5): + cont_trie = {key: cont_trie for key in utf8_all_keys[1]} + for key in utf8_all_keys[n]: + anything_else_trie[key] = cont_trie + + transition_trie_setdefault(seq_transitions_trie, anything_else_trie) + + # create new states and transitions + next_transitions = create_seq_transitions(seq_transitions_trie) + state_map.update(next_transitions) + map[state] = state_map + + return FSM( + alphabet=Alphabet(symbol_mapping), + states=states, + initial=fsm.initial, + finals=fsm.finals, + map=map, + ) + + +def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: + new_fsm = make_byte_level_fsm(fsm, keep_utf8) + return BetterFSM( + alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), + states=new_fsm.states, + initial=new_fsm.initial, + finals=new_fsm.finals, + map=new_fsm.map, + ) + + +def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: + """Construct an equivalent FSM with deterministic state labels.""" + old_to_new_trans_keys = { + trans_key: i + for i, (trans_key, _) in enumerate( + sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) + ) + } + + new_symbol_mapping = { + symbol: old_to_new_trans_keys[trans_key] + for symbol, trans_key in fsm.alphabet._symbol_mapping.items() + } + + new_alphabet = BetterAlphabet(new_symbol_mapping) + + new_map = { + from_state: { + old_to_new_trans_keys[trans_key]: to_state + for trans_key, to_state in trans_map.items() + } + for from_state, trans_map in fsm.map.items() + } + + old_to_new_states = {} + old_to_new_states[fsm.initial] = 0 + + i = 0 + seen = {fsm.initial} + old_state_queue = [fsm.initial] + while old_state_queue: + old_state = old_state_queue.pop(-1) + transitions = new_map[old_state] + sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) + for _, old_state in sorted_transitions: + if old_state not in seen: + old_state_queue.append(old_state) + seen.add(old_state) + if old_state not in old_to_new_states: + i += 1 + old_to_new_states[old_state] = i + + new_map = dict( + sorted( + ( + ( + old_to_new_states[from_state], + dict( + sorted( + ( + (trans_key, old_to_new_states[to_state]) + for trans_key, to_state in trans_map.items() + ), + key=lambda v: v[0], + ) + ), + ) + for from_state, trans_map in new_map.items() + ), + key=lambda v: v[0], + ) + ) + + new_initial = 0 + new_finals = frozenset( + sorted(old_to_new_states[old_state] for old_state in fsm.finals) + ) + new_states = frozenset(sorted(new_map.keys())) + + new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) + + return new_fsm, old_to_new_states + + +@numba.njit(nogil=True, cache=True) +def _walk_fsm( + fsm_transitions: Dict[Tuple[int, int], int], + fsm_initial: int, + fsm_finals: Set[int], + token_transition_keys: Sequence[int], + start_state: int, + full_match: bool = True, +) -> List[int]: + state = start_state + accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) + last_final_idx: int = numba.uint64(0) + + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. 
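+    # In full-match mode the walk only succeeds if the final key lands on a
+    # final state; otherwise, when a transition is missing, the longest
+    # previously seen prefix ending in a final state is returned (empty if
+    # there is none).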
+ for i, trans_key in enumerate(token_transition_keys): + new_state = fsm_transitions.get((state, trans_key)) + + if new_state is None: + if not full_match and last_final_idx > 0: + return accepted_states[:last_final_idx] + + return numba.typed.List.empty_list(numba.int64) + + state = new_state + + if state in fsm_finals: + last_final_idx = numba.uint64(i + 1) + + accepted_states.append(_nonoptional(state)) + + if full_match and last_final_idx - 1 != i: + return numba.typed.List.empty_list(numba.int64) + + return accepted_states + + +def walk_fsm( + fsm: BetterFSM, + token_transition_keys: Sequence[int], + start_state: int, + full_match: bool = True, +) -> List[int]: + fsm_finals = fsm.finals + + state = start_state + accepted_states: List[int] = [] + last_final_idx: int = 0 + + fsm_transitions = fsm.flat_transition_map + + # Iterate over token transition key sequence. The transition key + # sequence represents the FSM traversal rules of the tokens symbols. + for i, trans_key in enumerate(token_transition_keys): + new_state = fsm_transitions.get((state, trans_key)) + + if new_state is None: + if not full_match and last_final_idx > 0: + return accepted_states[:last_final_idx] + + return [] + + state = new_state + + if state in fsm_finals: + last_final_idx = i + 1 + + accepted_states.append(state) + + if full_match and last_final_idx - 1 != i: + return [] + + return accepted_states + + +def fsm_union( + fsms: Sequence[FSM], +) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: + """Construct an FSM representing the union of the FSMs in `fsms`. + + This is an updated version of `interegular.fsm.FSM.union` made to return an + extra map of component FSMs to the sets of state transitions that + correspond to them in the new FSM. 
+ + """ + + alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) + + indexed_fsms = tuple(enumerate(fsms)) + + initial = {i: fsm.initial for (i, fsm) in indexed_fsms} + + # Dedicated function accepting a "superset" and returning the next + # "superset" obtained by following this transition in the new FSM + def follow(current_state, new_transition: int): + next = {} + for i, f in indexed_fsms: + old_transition = new_to_old[i][new_transition] + if ( + i in current_state + and current_state[i] in f.map + and old_transition in f.map[current_state[i]] + ): + next[i] = f.map[current_state[i]][old_transition] + if not next: + raise OblivionError + return next + + states = [initial] + finals: Set[int] = set() + map: Dict[int, Dict[int, int]] = {} + + # Map component FSMs to their new state-to-state transitions, finals, and a + # map translating component FSM states to aggregate FSM states + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ] = {} + + i = 0 + while i < len(states): + state = states[i] + + # Add to the finals of the aggregate FSM whenever we hit a final in a + # component FSM + if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): + finals.add(i) + + # Compute the map for this state + map[i] = {} + for transition in alphabet.by_transition: + try: + next = follow(state, transition) + except OblivionError: + # Reached an oblivion state; don't list it + continue + else: + try: + # TODO: Seems like this could--and should--be avoided + j = states.index(next) + except ValueError: + j = len(states) + states.append(next) + + map[i][transition] = j + + for fsm_id, fsm_state in next.items(): + ( + fsm_transitions, + fsm_finals, + fsm_old_to_new, + ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) + old_from = state[fsm_id] + old_to = fsm_state + fsm_old_to_new.setdefault(old_from, set()).add(i) + fsm_old_to_new.setdefault(old_to, set()).add(j) + fsm_transitions.add((i, j)) + if fsm_state in fsms[fsm_id].finals: + fsm_finals.add(j) + + i += 1 + + fsm = FSM( + alphabet=alphabet, + states=range(len(states)), + initial=0, + finals=finals, + map=map, + __no_validation__=True, + ) + + fsm, old_to_new_states = make_deterministic_fsm(fsm) + _fsms_to_trans_finals = { + fsm_id: ( + {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, + {old_to_new_states[s] for s in finals}, + { + old_state: {old_to_new_states[new_state] for new_state in new_states} + for old_state, new_states in old_to_new.items() + }, + ) + for fsm_id, (transitions, finals, old_to_new) in sorted( + fsms_to_trans_finals.items(), key=lambda x: x[0] + ) + } + + return ( + fsm, + _fsms_to_trans_finals, + ) + + +def get_sub_fsms_from_seq( + state_seq: Sequence[int], + fsms_to_trans_finals: Dict[ + int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] + ], +) -> Generator[Tuple[int, bool, bool], None, None]: + """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. + + Parameters + ---------- + state_seq + A state sequence. + fsms_to_trans_finals + A map from FSM indices to tuples containing sets of their state transitions + and sets of the final/accept states. + + Returns + ------- + A generator returning tuples containing each sub-FSM index (in the order + they were union-ed to construct `fsm`) and booleans indicating whether or + not there is another valid transition from the last state in the sequence + for the associated sub-FSM (i.e. 
if the FSM can continue + accepting/matching) and whether or not the sequence ends in a final state + of the sub-FSM. + """ + state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) + last_fsm_state = state_seq[-1] + yield from ( + ( + # The sub-FMS index + fsm_idx, + # Is there another possible transition in this sub-FSM? + any(last_fsm_state == from_s for (from_s, to_s) in transitions), + # Is this sub-FSM in a final state? + state_seq[-1] in finals, + ) + for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() + if state_seq_transitions.issubset(transitions) + ) + + +@numba.njit(cache=True, nogil=True) +def state_scan_tokens( + fsm_transitions: Dict[Tuple[int, int], int], + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + fsm_initial: int, + fsm_finals: Set[int], + vocabulary: List[Tuple[str, Sequence[int]]], + vocab_trie: VocabTrie, + start_state: int, +) -> Set[Tuple[int, int]]: + res = set() + + # Initialize the stack with tokens having no prefixes + stack = numba.typed.List() + for token_transitions_seq in vocab_trie.get_children(): + stack.append(token_transitions_seq) + + # Process the tokens using the stack + while len(stack) > 0: + token_transition_seq = stack.pop() + state_seq = _walk_fsm( + fsm_transitions, + fsm_initial, + fsm_finals, + token_transition_seq, + start_state, + False, + ) + + if state_seq is not None and len(state_seq) < len(token_transition_seq): + continue + + for token_id in vocab_trie.get_token_ids(token_transition_seq): + res.add((token_id, state_seq[-1])) + + # Add successors to the stack + for new_token in vocab_trie.get_children(token_transition_seq): + stack.append(new_token) + + return res + + +@numba.njit(cache=True, nogil=True) +def get_token_transition_keys( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + token_str: str, +) -> Sequence[int]: + """ + Get the sequence of transition keys for an individual string + with respect to an FSMs alphabet symbol mapping + + This requires parsing the null-byte prefix rules of a byte-fsm: + - If two characters are prefixed by \x00, they are the grouped as a hex-byte + - Otherwise they are a standalone utf-8 character + """ + token_transition_keys = [] + i = 0 + while i < len(token_str): + if token_str[i] == "\x00" and i != len(token_str) - 1: + symbol = token_str[i : i + 3] + i += 3 + else: + symbol = token_str[i] + i += 1 + + token_transition_keys.append( + alphabet_symbol_mapping.get(symbol, alphabet_anything_value) + ) + + token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64) + for j in range(len(token_transition_keys)): + token_transition_keys_array[j] = token_transition_keys[j] + return token_transition_keys_array + + +@numba.njit(cache=True, nogil=True) +def get_vocabulary_transition_keys( + alphabet_symbol_mapping: Dict[str, int], + alphabet_anything_value: int, + vocabulary: List[Tuple[str, Sequence[int]]], +) -> List[Sequence[int]]: + """ + Calculate the sequence transition keys for each token str within a vocabulary + """ + vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:]) + for token_str, _ in vocabulary: + token_transition_keys = get_token_transition_keys( + alphabet_symbol_mapping, alphabet_anything_value, token_str + ) + vocab_transition_keys.append(token_transition_keys) + + return vocab_transition_keys + + +def create_fsm_index_end_to_end( + fsm_info: FSMInfo, + vocabulary: List[Tuple[str, Sequence[int]]], +) -> Dict[int, Set[Tuple[int, int]]]: + """Create an FSM 
state-to-vocabulary map/index through end-to-end token parsing.""" + + # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this + # code, too. + states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {} + seen: Set[int] = set() + next_states = {fsm_info.initial} + + vocabulary_transitions = get_vocabulary_transition_keys( + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + vocabulary, + ) + vocab_trie = VocabTrie(vocabulary_transitions, vocabulary) + + pbar = tqdm( + total=len(set(fsm_info.transitions.values())) + + 1, # all transitions plus initial + desc="Compiling FSM index for all state transitions", + ) + + while next_states: + start_state = next_states.pop() + + token_ids_end_states = state_scan_tokens( + fsm_info.transitions, + fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + fsm_info.initial, + fsm_info.finals, + vocabulary, + vocab_trie, + start_state, + ) + + for token_id_and_end_state in token_ids_end_states: + states_to_token_subsets.setdefault(start_state, set()).add( + token_id_and_end_state + ) + end_state = token_id_and_end_state[1] + if end_state not in seen: + next_states.add(end_state) + + if start_state not in seen: + pbar.update(1) + seen.add(start_state) + + pbar.close() + + return states_to_token_subsets + + +re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") +re_replacement_seq = re.compile(r"^▁*�+$") + + +# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode +@lru_cache() +def gpt2_bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. + + The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab + if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for + decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup + tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +@lru_cache() +def gpt2_unicode_to_bytes(): + return {v: k for k, v in gpt2_bytes_to_unicode().items()} + + +# TODO: Cannot cache typed collections to disk, yet. See +# https://github.com/numba/numba/issues/4698 +@lru_cache +def reduced_vocabulary( + tokenizer: "Tokenizer", +) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: + """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" + empty_token_ids = set() + vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} + for token, token_idx in tokenizer.vocabulary.items(): + if token in tokenizer.special_tokens: + continue + + token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string( + token + ) + + if token_str: + # invalid utf-8 sequences are replaced with � (\ufffd), but there + # might also be tokens specifically for �, ��, ���, etc. 
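+            # `re_replacement_seq` matches tokens made up entirely of
+            # replacement characters (optionally preceded by ▁), so only
+            # genuinely undecodable byte tokens take the byte-level path below.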
+ if "\ufffd" in token_str and not re_replacement_seq.match(token): + if re_llama_byte_token.match(token): + # llama-like tokenizers have <0xXX> tokens for all + # bytes >= 0x80 and represent all incomplete utf-8 + # sequences using such tokens + token_bytes = [int(token[3:5], 16)] + else: + # gpt2-like tokenizers have multi-byte tokens that can + # have a mix of full and incomplete utf-8 characters, + # for example, b` \xf0` can be one token; these tokenizers + # map each byte to a valid utf-8 character + token_bytes = cast( + List[int], [gpt2_unicode_to_bytes().get(c) for c in token] + ) + if None in token_bytes: + raise RuntimeError( + f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}" + ) + token_str = "".join(byte_symbol(b) for b in token_bytes) + + vocabulary.setdefault(token_str, []).append(token_idx) + else: + empty_token_ids.add(numba.int64(token_idx)) + + vocabulary_nb = numba.typed.List.empty_list( + numba.types.Tuple( + ( + nb_unicode_type, + numba.int64[:], + ) + ) + ) + for token_str, token_ids in vocabulary.items(): + token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) + vocabulary_nb.append((token_str, token_ids_np)) + + return vocabulary_nb, empty_token_ids + + +def create_fsm_index_tokenizer( + fsm: BetterFSM, + tokenizer: "Tokenizer", +) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: + """Construct an FMS index from a tokenizer. + + This uses the end-to-end approach of `create_fsm_index_end_to_end`. + + .. warning:: + + `fsm` needs to be deterministically ordered so that future caching makes sense. + + """ + vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) + + states_to_token_subsets = create_fsm_index_end_to_end(fsm.fsm_info, vocabulary) + + # Allow transitions to EOS from all terminals FSM states that are + # reachable + # TODO: Do we really need this anymore? + for state in fsm.fsm_info.finals: + subset = states_to_token_subsets.get(state) + if subset is not None: + subset.add((tokenizer.eos_token_id, state)) + + # Convert to token-to-end-state maps + states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()} + + return states_to_token_subsets, empty_token_ids diff --git a/build/lib/outlines/fsm/types.py b/build/lib/outlines/fsm/types.py new file mode 100644 index 000000000..5695dee07 --- /dev/null +++ b/build/lib/outlines/fsm/types.py @@ -0,0 +1,81 @@ +import datetime +from enum import EnumMeta +from typing import Any, Protocol, Tuple, Type + +from typing_extensions import _AnnotatedAlias, get_args + +INTEGER = r"[+-]?(0|[1-9][0-9]*)" +BOOLEAN = "(True|False)" +FLOAT = rf"{INTEGER}(\.[0-9]+)?([eE][+-][0-9]+)?" +DATE = r"(\d{4})-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])" +TIME = r"([0-1][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])" +DATETIME = rf"({DATE})(\s)({TIME})" + + +class FormatFunction(Protocol): + def __call__(self, sequence: str) -> Any: + ... 
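+
+# Illustrative usage sketch (not part of the public API): the helper below
+# maps a Python type to a (regex, formatter) pair, e.g.
+#     regex_str, fmt = python_types_to_regex(int)
+#     fmt("42")  # -> 42, with regex_str == INTEGER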
+ + +def python_types_to_regex(python_type: Type) -> Tuple[str, FormatFunction]: + # If it is a custom type + if isinstance(python_type, _AnnotatedAlias): + json_schema = get_args(python_type)[1].json_schema + type_class = get_args(python_type)[0] + + custom_regex_str = json_schema["pattern"] + + def custom_format_fn(sequence: str) -> Any: + return type_class(sequence) + + return custom_regex_str, custom_format_fn + + if isinstance(python_type, EnumMeta): + values = python_type.__members__.keys() + enum_regex_str: str = "(" + "|".join(values) + ")" + + def enum_format_fn(sequence: str) -> str: + return str(sequence) + + return enum_regex_str, enum_format_fn + + if python_type == float: + + def float_format_fn(sequence: str) -> float: + return float(sequence) + + return FLOAT, float_format_fn + elif python_type == int: + + def int_format_fn(sequence: str) -> int: + return int(sequence) + + return INTEGER, int_format_fn + elif python_type == bool: + + def bool_format_fn(sequence: str) -> bool: + return bool(sequence) + + return BOOLEAN, bool_format_fn + elif python_type == datetime.date: + + def date_format_fn(sequence: str) -> datetime.date: + return datetime.datetime.strptime(sequence, "%Y-%m-%d").date() + + return DATE, date_format_fn + elif python_type == datetime.time: + + def time_format_fn(sequence: str) -> datetime.time: + return datetime.datetime.strptime(sequence, "%H:%M:%S").time() + + return TIME, time_format_fn + elif python_type == datetime.datetime: + + def datetime_format_fn(sequence: str) -> datetime.datetime: + return datetime.datetime.strptime(sequence, "%Y-%m-%d %H:%M:%S") + + return DATETIME, datetime_format_fn + else: + raise NotImplementedError( + f"The Python type {python_type} is not supported. Please open an issue." + ) diff --git a/build/lib/outlines/fsm/vocab_trie.py b/build/lib/outlines/fsm/vocab_trie.py new file mode 100644 index 000000000..52d11b0cf --- /dev/null +++ b/build/lib/outlines/fsm/vocab_trie.py @@ -0,0 +1,241 @@ +import operator +from typing import List, Optional, Sequence, Tuple + +import numpy as np +from numba import njit, typed, types +from numba.cpython.hashing import ( + _Py_uhash_t, + _PyHASH_XXPRIME_1, + _PyHASH_XXPRIME_2, + _PyHASH_XXPRIME_5, + _PyHASH_XXROTATE, + process_return, +) +from numba.experimental import jitclass, structref +from numba.extending import overload +from numba.typed import Dict + +########################### +# Dict With Int[:] Key Impl +########################### + + +# Register type +@structref.register +class IntArrayDictType(types.StructRef): + """ + Represents a dictionary using int64[:] as keys, + intended for byte-level FSM representation with int64[:] transition. 
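+    Keys are stored via a 64-bit hash of the array (see `hash_key` below)
+    rather than by the array object itself.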
+ """ + + def preprocess_fields(self, fields): + return tuple( + (name, typ.dtype if isinstance(typ, types.TypeRef) else typ) + for name, typ in fields + ) + + +class IntArrayDict(structref.StructRefProxy): + """Python proxy""" + + @property + def wrapped_dict(self): + return IntArrayDict_get_wrapped_dict(self) # noqa: F821 + + +structref.define_proxy(IntArrayDict, IntArrayDictType, ["wrapped_dict"]) + + +@njit +def hash_key(key): + """ + XXH64 Hash for int64[:] keys + adapted from https://github.com/numba/numba/blob/556545/numba/cpython/hashing.py + """ + acc = _PyHASH_XXPRIME_5 + for i in range(key.shape[0]): + x = key[i] + lane = hash(x) + if lane == _Py_uhash_t(-1): + return -1 + acc += lane * _PyHASH_XXPRIME_2 + acc = _PyHASH_XXROTATE(acc) + acc *= _PyHASH_XXPRIME_1 + + acc += key.shape[0] ^ (_PyHASH_XXPRIME_5 ^ _Py_uhash_t(3527539)) + + if acc == _Py_uhash_t(-1): + return process_return(1546275796) + + return process_return(acc) + + +@overload(IntArrayDict) +def custom_int_array_dict_constructor(value_type): + if isinstance(value_type, types.Type): + + def impl(value_type): + wrapped_dictionary = Dict.empty(types.intp, value_type) + return IntArrayDict(wrapped_dictionary) + + return impl + + +@overload(operator.getitem) +def ol_int_array_dict_getitem(inst, key): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key): + return inst.wrapped_dict[hash_key(key)] + + return impl + + +@overload(operator.setitem) +def ol_int_array_dict_setitem(inst, key, value): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key, value): + inst.wrapped_dict[hash_key(key)] = value + + return impl + + +@overload(operator.contains) +def ol_int_array_dict_contains(inst, key): + if isinstance(inst, IntArrayDictType): + + def impl(inst, key): + return hash_key(key) in inst.wrapped_dict + + return impl + + +################# +# Vocab Trie Impl +################# + +nb_int64_array_type = types.int64[:] + +# use intp keys as that is the hash type, +# but the true key type is nb_int64_array_type +IntArrayToIntType = IntArrayDictType( + (("wrapped_dict", types.DictType(types.intp, types.int64)),) +) +IntArrayToIntArrayType = IntArrayDictType( + (("wrapped_dict", types.DictType(types.intp, nb_int64_array_type)),) +) + + +@jitclass( + [ + ("token_to_token_key", IntArrayToIntType), + ("token_key_to_token", types.DictType(types.int64, nb_int64_array_type)), + ( + "token_key_to_child_token_keys", + types.DictType(types.int64, nb_int64_array_type), + ), + ("token_to_token_ids", IntArrayToIntArrayType), + ], +) +class VocabTrie: + """ + VocabTrie: Class for efficient traversal of the vocabulary + Bidirectional mapping between trie node ID and nb_unichar_2_type token + - token_to_token_key: Dict[nb_unichar_2_array_type, int] + - token_key_to_token: Dict[int, nb_unichar_2_array_type] + Allow retrieval of children in trie + - token_key_to_child_token_keys: Dict[int, int64[:]] + Allow retrieval of of token_ids for a given token + - token_to_token_ids: Dict[nb_unichar_2_array_type, int64[:]] + Trie structure: + Only members of the vocabulary are included as nodes, no intermediates. + Structured to guarantee that recursive calls to get_children() + will return every token once, only once. + Given a vocabulary of ["a", "ab", "abc", "ac", "ace", "apple"], + the children of "a" are "ab", "ac", "apple". + "abc" and "ace" are excluded because they have intermediate parents in the vocabulary. 
+ """ + + def __init__( + self, + all_token_transitions: List[Sequence[int]], + vocabulary: List[Tuple[str, Sequence[int]]], + ): + self.token_to_token_key = IntArrayDict( + typed.Dict.empty(types.intp, types.int64) + ) + self.token_key_to_token = typed.Dict.empty( + key_type=types.int64, value_type=nb_int64_array_type + ) + self.token_key_to_child_token_keys = typed.Dict.empty( + key_type=types.int64, value_type=nb_int64_array_type + ) + self.token_to_token_ids = IntArrayDict( + typed.Dict.empty(types.intp, nb_int64_array_type) + ) + + self._insert(all_token_transitions, vocabulary) + + def _insert( + self, + all_token_transitions: List[Sequence[int]], + vocabulary: List[Tuple[str, Sequence[int]]], + ) -> None: + # Initialize an empty array for the root token key to store child token keys + self.token_key_to_child_token_keys[-1] = np.empty((0,), types.int64) + + # It's necessary to insert shorter transition sequences (prefixes) first + sorted_idx_transition_seq = sorted( + enumerate(all_token_transitions), key=lambda x: len(x[1]) + ) + + for idx, token_transitions in sorted_idx_transition_seq: + token_ids = vocabulary[idx][1] + if token_transitions not in self.token_to_token_key: + # create bimapping between token and token_key (tokens trie node key) + self.token_to_token_key[token_transitions] = idx + self.token_key_to_token[idx] = token_transitions + + # find parent token key + parent_token_key = -1 # root token + for i in range(len(token_transitions) - 1, -1, -1): + prefix_token = token_transitions[:i] + + if prefix_token in self.token_to_token_key: + parent_token_key = self.token_to_token_key[prefix_token] + break + # map parent token to current token + self.token_key_to_child_token_keys[parent_token_key] = np.append( + self.token_key_to_child_token_keys[parent_token_key], + np.array([idx]), + ) + + # map current token to empty list of children + self.token_key_to_child_token_keys[idx] = np.empty((0,), types.int64) + + # set current tokens token ids + self.token_to_token_ids[token_transitions] = token_ids + + else: + # if exists, append to current tokens token ids + self.token_to_token_ids[token_transitions] = np.append( + self.token_to_token_ids[token_transitions], token_ids + ) + + def get_children(self, token_transitions: Optional[Sequence[int]] = None): + """ + Get the token_ids of all children for the given token_id. + If token_id is None, get the root children. + """ + if token_transitions is None: + token_key = -1 + else: + token_key = self.token_to_token_key[token_transitions] + + child_token_keys = self.token_key_to_child_token_keys[token_key] + + return [self.token_key_to_token[token_key] for token_key in child_token_keys] + + def get_token_ids(self, token): + return self.token_to_token_ids[token] diff --git a/build/lib/outlines/function.py b/build/lib/outlines/function.py new file mode 100644 index 000000000..48577be8f --- /dev/null +++ b/build/lib/outlines/function.py @@ -0,0 +1,117 @@ +import importlib.util +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union + +import requests + +from outlines import generate, models + +if TYPE_CHECKING: + from outlines.generate.api import SequenceGenerator + from outlines.prompts import Prompt + + +@dataclass +class Function: + """Represents an Outlines function. + + Functions are a convenient way to encapsulate a prompt template, a language + model and a Pydantic model that define the output structure. 
Once defined, + the function can be called with arguments that will be used to render the + prompt template. + + """ + + prompt_template: "Prompt" + schema: Union[str, Callable, object] + model_name: str + generator: Optional["SequenceGenerator"] = None + + @classmethod + def from_github(cls, program_path: str, function_name: str = "fn"): + """Load a function stored on GitHub""" + program_content = download_from_github(program_path) + function = extract_function_from_file(program_content, function_name) + + return function + + def init_generator(self): + """Load the model and initialize the generator.""" + model = models.transformers(self.model_name) + self.generator = generate.json(model, self.schema) + + def __call__(self, *args, **kwargs): + """Call the function. + + .. warning:: + + This currently does not support batching. + + Parameters + ---------- + args + Values to pass to the prompt template as positional arguments. + kwargs + Values to pass to the prompt template as keyword arguments. + + """ + if self.generator is None: + self.init_generator() + + prompt = self.prompt_template(*args, **kwargs) + return self.generator(prompt) + + +def download_from_github(short_path: str): + """Download the file in which the function is stored on GitHub.""" + GITHUB_BASE_URL = "https://github.com/raw" + BRANCH = "main" + + path = short_path.split("/") + if len(path) < 3: + raise ValueError( + "Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}." + ) + elif short_path[-3:] == ".py": + raise ValueError("Do not append the `.py` extension to the program name.") + + username = path[0] + repo = path[1] + path_to_file = path[2:] + + url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py" + result = requests.get(url) + + if result.status_code == 200: + return result.text + elif result.status_code == 404: + raise ValueError( + f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly." + ) + else: + result.raise_for_status() + + +def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]: + """Extract a function object from a downloaded file.""" + + spec = importlib.util.spec_from_loader( + "outlines_function", loader=None, origin="github" + ) + if spec is not None: + module = importlib.util.module_from_spec(spec) + exec(content, module.__dict__) + + try: + fn = getattr(module, function_name) + except AttributeError: + raise AttributeError( + "Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct." 
+ ) + + if not isinstance(fn, module.outlines.Function): + raise TypeError( + f"The `{function_name}` variable in the program must be an instance of `outlines.Function`" + ) + + return fn diff --git a/build/lib/outlines/generate/__init__.py b/build/lib/outlines/generate/__init__.py new file mode 100644 index 000000000..f28cbd80d --- /dev/null +++ b/build/lib/outlines/generate/__init__.py @@ -0,0 +1,8 @@ +from .api import SequenceGenerator +from .cfg import cfg +from .choice import choice +from .format import format +from .fsm import fsm +from .json import json +from .regex import regex +from .text import text diff --git a/build/lib/outlines/generate/api.py b/build/lib/outlines/generate/api.py new file mode 100644 index 000000000..51a995664 --- /dev/null +++ b/build/lib/outlines/generate/api.py @@ -0,0 +1,531 @@ +import datetime +from dataclasses import dataclass +from typing import TYPE_CHECKING, Iterator, List, Optional, Union + +from outlines.generate.generator import sequence_generator +from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler + +if TYPE_CHECKING: + import torch + +FormattedOutput = Union[ + str, int, float, bool, datetime.date, datetime.time, datetime.datetime +] + + +class SequenceGenerator: + def __init__( + self, + fsm, + model, + sampler, + device, + ): + self.fsm = fsm + self.model = model + self.sampler = sampler + self.tokenizer = model.tokenizer + self.device = device + self.num_samples = sampler.samples + + def get_generated_token_ids( + self, + prompt_token_ids: "torch.Tensor", + token_ids: "torch.Tensor", + ) -> List["torch.Tensor"]: + """Get the tokens generated so far. + + Parameters + ---------- + prompt_token_ids + Tensor that contains the token ids of the sequences' prompts. + token_ids + The generated token ids. + + Returns + ------- + A tensor that contains the token ids that have been generated so far. + + """ + prompt_lengths = [len(prompt) for prompt in prompt_token_ids] + token_ids = [ + cur_token_ids[length:] + for cur_token_ids, length in zip(token_ids, prompt_lengths) + ] + + return token_ids + + def is_stop_sequence_found( + self, generated_sequences: List[str], stop_sequences: List[str] + ) -> bool: + """Determine whether one of the stop sequences has been generated. + + Parameters + ---------- + generated_sequences + The list of sequences generated so far. + stop_sequences + The list that contains the sequence which stop the generation when + found. + + Returns + ------- + True if at least one of the stop sequences has been found in each generated + sequence. + + """ + return all( + [ + any([seq in generated for seq in stop_sequences]) + for generated in generated_sequences + ] + ) + + def strip_stop_sequences( + self, sequence: str, stop_sequences: Optional[List[str]] + ) -> str: + """Remove the stop sequences from the generated sequences. + + Parameters + ---------- + sequence + One of the generated sequences. + stop_sequences + The list that contains the sequence which stop the generation when + found. 
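+
+        As a sketch of the implemented behaviour: with stop_sequences=["\n"],
+        "Hello.\nWorld" is truncated to "Hello.\n" (the stop sequence itself
+        is kept).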
+ + """ + if stop_sequences: + match_indexes = [sequence.find(seq) for seq in stop_sequences] + if any([index != -1 for index in match_indexes]): + # select the stop_sequence that is found first in the sequence + min_match_index_value = min([i for i in match_indexes if i != -1]) + min_match_index_pos = match_indexes.index(min_match_index_value) + sequence = sequence[ + : match_indexes[min_match_index_pos] + + len(stop_sequences[min_match_index_pos]) + ] + + return sequence + + def format_sequence(self, sequence: str) -> FormattedOutput: + """Translate the generated sequence to another type. + + This method is for instance overridden when generating JSON to either + return a dictionnary or a Pydantic model. + + Parameters + ---------- + sequence + A generated sequences. + + Returns + ------- + The formatted sequence. + + """ + return sequence + + def __call__( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional["torch.Generator"] = None, + ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]: + """Generate the full text sequence. + + Since `SequenceGenerator.stream` calls the tokenizer at every step this + method loops over the generator returned by `sequence_generator` itself + so the tokenizer is called only once after all token ids have been + generated. + + Parameters + ---------- + prompts + A string or list of strings that are passed to the model before + generating the first token. + max_tokens + An integer representing maximum number of tokens that will be generated + (per prompt) + stop_at + A string or list of strings at which the text generated will stop + rng + The random number generator. Defaults to a non-seeded `torch.Generator` + instance. + + Returns + ------- + The generation(s), potentially cast to another type. + """ + import torch + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(stop_at, str): + stop_at = [stop_at] + + stop_sequences = stop_at + num_samples = self.num_samples + + if rng is None: + rng = torch.Generator(device=self.device) + rng.seed() + + prompt_token_ids, attention_masks = self.tokenizer.encode(prompts) + prompt_token_ids = prompt_token_ids.to(self.device) + attention_masks = attention_masks.to(self.device) + + # To draw multiple samples we repeat the prompt as many times + # as there are samples. We copy the FSMs and initialize the + # FSM states. 
+ num_samples = self.num_samples + batch_size = len(prompts) + + prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + fsm_states = [0 for _ in range(batch_size * num_samples)] + fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] + weights = torch.zeros( + (batch_size * num_samples), dtype=torch.float, device=self.device + ) + + states = sequence_generator( + self.model, + self.sampler, + fsms, + prompt_token_ids, + weights, + attention_masks, + fsm_states, + rng=rng, + ) + + while True: + try: + last_state = next(states) + if max_tokens or stop_sequences: + token_ids = last_state.token_ids + generated_token_ids = self.get_generated_token_ids( + prompt_token_ids, token_ids + ) + if max_tokens and len(generated_token_ids[0]) >= max_tokens: + break + if stop_sequences and self.is_stop_sequence_found( + self.tokenizer.decode(generated_token_ids), stop_sequences + ): + break + except StopIteration: + break + + token_ids = last_state.token_ids + generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids) + + generated = self.tokenizer.decode(generated_token_ids) + stripped = [ + self.strip_stop_sequences(sequence, stop_sequences) + for sequence in generated + ] + formatted = [self.format_sequence(sequence) for sequence in stripped] + + # We reshape the output to (batch_size, sample_size) + output: List[List[FormattedOutput]] = list() + for i in range(batch_size): + output.append(formatted[i : i + num_samples]) + + # We remove leading dimensions for the output + if batch_size == 1 and num_samples == 1: + return output[0][0] + elif batch_size == 1: + return output[0] + elif num_samples == 1: + return [samples[0] for samples in output] + else: + return output + + def stream( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + rng: Optional["torch.Generator"] = None, + ) -> Iterator[Union[List[str], str, List[List[str]]]]: + """Generate the text sequence one token at a time. + + Since `Tokenizer.decode` strips the whitespaces from the tokens we have no + choice but to decode the generated token ids at each step and compare the + current decoded strings to the previously decoded strings. + + Parameters + ---------- + prompts + A string or list of strings that are passed to the model before + generating the first token. + max_tokens + An integer representing maximum number of tokens that will be generated + (per prompt) + stop_at + A string or list of strings at which the text generated will stop + rng + The random number generator. Defaults to a non-seeded `torch.Generator` + instance. + + Returns + ------- + A string or list of strings that contain the generated text. + + """ + import torch + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(stop_at, str): + stop_at = [stop_at] + + stop_sequences = stop_at + num_samples = self.num_samples + + prompt_token_ids, attention_masks = self.tokenizer.encode(prompts) + prompt_token_ids = prompt_token_ids.to(self.device) + attention_masks = attention_masks.to(prompt_token_ids.device) + + # To draw multiple samples we repeat the prompt as many times + # as there are samples. We copy the FSMs and initialize the + # FSM states. 
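A small illustrative sketch (not part of the diff) of the incremental decoding strategy described in the `stream` docstring above and applied in `token_generator` below: the newly generated text is recovered by slicing the previously decoded string off the current one.

    previously_decoded = "The capital of France is"
    currently_decoded = "The capital of France is Paris"

    new_text = currently_decoded[len(previously_decoded):]
    print(repr(new_text))   # ' Paris'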
+ num_samples = self.num_samples + batch_size = len(prompts) + + prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + fsm_states = [0 for _ in range(batch_size * num_samples)] + fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] + weights = torch.zeros( + (batch_size * num_samples), + dtype=torch.float, + device=prompt_token_ids.device, + ) + + if rng is None: + rng = torch.Generator(device=prompt_token_ids.device) + rng.seed() + + states = sequence_generator( + self.model, + self.sampler, + fsms, + prompt_token_ids, + weights, + attention_masks, + fsm_states, + rng=rng, + ) + + def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: + previously_generated_sequences = [ + "" for _ in range(batch_size) + ] * num_samples + num_generated = 0 + is_stop_at_reached = [False for _ in range(batch_size)] * num_samples + while True: + if (max_tokens and num_generated >= max_tokens) or all( + is_stop_at_reached + ): + return + try: + sequence = next(states) + num_generated += 1 + except StopIteration: + return + generated_token_ids = sequence.token_ids[:, -num_generated:] + generated_sequences = self.tokenizer.decode(generated_token_ids) + if stop_sequences: + is_stop_at_reached = [ + stop + or self.is_stop_sequence_found( + [generated_sequence], stop_sequences + ) + for generated_sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + + generated_sequences = [ + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence + for sequence, stop in zip( + generated_sequences, is_stop_at_reached + ) + ] + next_tokens = [ + token[len(sequence) :] + for token, sequence, stop in zip( + generated_sequences, + previously_generated_sequences, + is_stop_at_reached, + ) + ] + previously_generated_sequences = generated_sequences + # We reshape the output to (batch_size, sample_size) + output: List[List[str]] = list() + for i in range(batch_size): + output.append(next_tokens[i : i + num_samples]) + + # We remove leading dimensions for the output + if batch_size == 1 and num_samples == 1: + yield output[0][0] + elif batch_size == 1: + yield output[0] + elif num_samples == 1: + yield [samples[0] for samples in output] + else: + yield output + + return token_generator() + + +@dataclass(frozen=True) +class GenerationParameters: + """Generation parameters used in Outlines' public API.""" + + max_tokens: Optional[int] + stop_at: Optional[Union[str, List[str]]] + seed: Optional[int] + + +@dataclass(frozen=True) +class SamplingParameters: + """Sampling parameters available in Outlines.""" + + sampler: str + num_samples: int = 1 + top_p: Optional[float] = None + top_k: Optional[int] = None + temperature: Optional[float] = None + + +class SequenceGeneratorAdapter: + """Class used to unify the interface to the model providers' + generation functions. + + Attributes + ---------- + model + The wrapped model. + logits_processor + The logits processor to use to generate text. + sampler + The sampler to use to generate text. 
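For orientation (an illustrative sketch, not part of the diff, assuming the modules import as added in this patch), the two frozen dataclasses defined above simply bundle the user-facing generation and sampling knobs that are later handed to the model backends.

    from outlines.generate.api import GenerationParameters, SamplingParameters

    gen_params = GenerationParameters(max_tokens=128, stop_at=["\n\n"], seed=42)
    sampling_params = SamplingParameters(
        "multinomial", num_samples=2, top_p=0.9, temperature=0.7
    )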
+ + """ + + def __init__(self, model, logits_processor, sampler): + self.model = model + self.logits_processor = logits_processor + + if isinstance(sampler, MultinomialSampler): + self.sampling_params = SamplingParameters( + "multinomial", + sampler.samples, + sampler.top_p, + sampler.top_k, + sampler.temperature, + ) + elif isinstance(sampler, GreedySampler): + self.sampling_params = SamplingParameters( + "greedy", sampler.samples, None, None, 0.0 + ) + elif isinstance(sampler, BeamSearchSampler): + self.sampling_params = SamplingParameters( + "beam_search", sampler.samples, None, None, 1.0 + ) + + def prepare_generation_parameters( + self, + max_tokens: Optional[int], + stop_at: Optional[Union[str, List[str]]], + seed: Optional[int], + ): + if isinstance(stop_at, str): + stop_at = [stop_at] + + generation_params = GenerationParameters( + max_tokens, + stop_at, + seed, + ) + + return generation_params + + def format_sequence(self, sequence: str) -> FormattedOutput: + """Translate the generated sequence to another type. + + This method is for instance overridden when generating JSON to either + return a dictionnary or a Pydantic model. + + Parameters + ---------- + sequence + A generated sequences. + + Returns + ------- + The formatted sequence. + + """ + return sequence + + def __call__( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """Generate text from a prompt of list of prompts.""" + + def format(sequences): + """Apply formatting to every string in a completion.""" + if isinstance(sequences, list): + return [format(sequence) for sequence in sequences] + else: + return self.format_sequence(sequences) + + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + + completions = self.model.generate( + prompts, + generation_params, + self.logits_processor, + self.sampling_params, + **model_specific_params, + ) + + return format(completions) + + def stream( + self, + prompts: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + seed: Optional[int] = None, + **model_specific_params, + ): + """Return a text generator from a prompt or a list of prompts.""" + generation_params = self.prepare_generation_parameters( + max_tokens, stop_at, seed + ) + return self.model.stream( + prompts, + generation_params, + self.logits_processor, + self.sampling_params, + **model_specific_params, + ) diff --git a/build/lib/outlines/generate/cfg.py b/build/lib/outlines/generate/cfg.py new file mode 100644 index 000000000..0a6698b08 --- /dev/null +++ b/build/lib/outlines/generate/cfg.py @@ -0,0 +1,64 @@ +from functools import singledispatch + +from outlines.fsm.guide import CFGGuide +from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter +from outlines.models import OpenAI +from outlines.models.llamacpp import LlamaCpp +from outlines.models.vllm import VLLM +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenerator: + """Generate text in the language of a Context-Free Grammar + + Arguments + --------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. 
+ + Returns + ------- + A `SequenceGenerator` instance that generates text. + + """ + fsm = CFGGuide(cfg_str, model.tokenizer) + device = model.device + generator = SequenceGenerator(fsm, model, sampler, device) + + return generator + + +@cfg.register(VLLM) +def cfg_vllm( + model: VLLM, + cfg_str: str, + sampler: Sampler = multinomial(), +): + raise NotImplementedError( + "The CFG Logits processor is not available for the vLLM integration." + ) + + +@cfg.register(LlamaCpp) +def cfg_llamacpp( + model: LlamaCpp, + cfg_str: str, + sampler: Sampler = multinomial(), +): + from outlines.integrations.llamacpp import CFGLogitsProcessor + + logits_processor = CFGLogitsProcessor(cfg_str, model.model) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@cfg.register(OpenAI) +def cfg_openai(model, cfg_str: str, sampler: Sampler = multinomial()): + raise NotImplementedError( + "Cannot use grammar-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." + ) diff --git a/build/lib/outlines/generate/choice.py b/build/lib/outlines/generate/choice.py new file mode 100644 index 000000000..6718f26b2 --- /dev/null +++ b/build/lib/outlines/generate/choice.py @@ -0,0 +1,36 @@ +from functools import singledispatch +from typing import Callable, List + +from outlines.generate.api import SequenceGenerator +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from .regex import regex + + +@singledispatch +def choice( + model, choices: List[str], sampler: Sampler = multinomial() +) -> SequenceGenerator: + regex_str = r"(" + r"|".join(choices) + r")" + + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: x + + return generator + + +@choice.register(OpenAI) +def choice_openai( + model: OpenAI, choices: List[str], sampler: Sampler = multinomial() +) -> Callable: + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The OpenAI API does not support any other sampling algorithm " + + "that the multinomial sampler." + ) + + def generate_choice(prompt: str, max_tokens: int = 1): + return model.generate_choice(prompt, choices, max_tokens) + + return generate_choice diff --git a/build/lib/outlines/generate/format.py b/build/lib/outlines/generate/format.py new file mode 100644 index 000000000..d87a3fe70 --- /dev/null +++ b/build/lib/outlines/generate/format.py @@ -0,0 +1,45 @@ +from functools import singledispatch + +from outlines.fsm.types import python_types_to_regex +from outlines.generate.api import SequenceGenerator +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from .regex import regex + + +@singledispatch +def format(model, python_type, sampler: Sampler = multinomial()) -> SequenceGenerator: + """Generate structured data that can be parsed as a Python type. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + python_type: + A Python type. The output of the generator must be parseable into + this type. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the Python type + and translates this text into the corresponding type. 
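An illustrative note (not part of the diff) on the `choice` constructor added above: it reduces the problem to regex-structured generation by joining the options into a single alternation pattern.

    choices = ["Positive", "Negative", "Neutral"]
    regex_str = r"(" + r"|".join(choices) + r")"
    print(regex_str)   # (Positive|Negative|Neutral)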
+ + """ + regex_str, format_fn = python_types_to_regex(python_type) + generator = regex(model, regex_str, sampler) + generator.format_sequence = format_fn + + return generator + + +@format.register(OpenAI) +def format_openai(model, python_type, sampler: Sampler = multinomial()): + raise NotImplementedError( + "Cannot use Python type-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." + ) diff --git a/build/lib/outlines/generate/fsm.py b/build/lib/outlines/generate/fsm.py new file mode 100644 index 000000000..80db350f0 --- /dev/null +++ b/build/lib/outlines/generate/fsm.py @@ -0,0 +1,14 @@ +import interegular + +from outlines.fsm.guide import RegexGuide +from outlines.generate.api import SequenceGenerator +from outlines.samplers import Sampler, multinomial + + +def fsm( + model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() +) -> SequenceGenerator: + fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) + device = model.device + generator = SequenceGenerator(fsm, model, sampler, device) + return generator diff --git a/build/lib/outlines/generate/generator.py b/build/lib/outlines/generate/generator.py new file mode 100644 index 000000000..e506aa035 --- /dev/null +++ b/build/lib/outlines/generate/generator.py @@ -0,0 +1,312 @@ +import dataclasses +import math +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, List, Optional, Tuple + +if TYPE_CHECKING: + import torch + + from outlines.fsm.guide import Guide + + +class ContextLengthExceededError(Exception): + pass + + +@dataclasses.dataclass(frozen=True) +class GenerationState: + token_ids: "torch.Tensor" + kv_cache: "torch.Tensor" + logits: "torch.Tensor" + weights: "torch.Tensor" + fsm_states: List[int] + + +def sequence_generator( + model: Callable, + sampler: Callable, + fsms: List["Guide"], + token_ids: "torch.Tensor", + sequence_weights: "torch.Tensor", + attention_masks: "torch.Tensor", + fsm_states: List[int], + rng: "torch.Generator", +) -> Iterator[GenerationState]: + """Generates sequences of tokens. + + Parameters + ---------- + model + A callable that generates a probability distribution over the + vocabulary when passed a tensor of token ids. + sampler + A callable that returns the next token ids, their ancestor sequence and + the updated sequence weights when passed a distribution over the + vocabulary. + token_ids + A tensor of token ids on which the sequence distribution is conditioned, of + shape ``(n_seqs, n_prompt_tokens)`` + sequence_weights + A tensor that contains the initial weights of the sequences, of shape + ``(n_seqs,)`` + attention_masks + A tensor of tensors that represent the tokens considered at the attention + layer, of shape ``(n_seqs, n_prompt_tokens)``. + fsms + List of finite-state machines that drive the text generation, + one for each sequence in the batch. + fsm_states + The initial states of the finite-state machine for each sequence in the batch. + + Yields + ------ + A new sequence. + + """ + import torch + + if rng is None: + rng = torch.Generator() + + kv_cache = None + + while True: + try: + logits, kv_cache = model(token_ids, attention_masks, kv_cache) + except IndexError: # Exceeding the context length + raise ContextLengthExceededError( + "The input length exceeds the context length of the model." 
+ ) + + allowed_tokens = get_allowed_tokens(fsms, fsm_states) + biased_logits = bias_logits(logits, allowed_tokens) + next_token_ids, ancestors, sequence_weights = sampler( + biased_logits, sequence_weights, rng + ) + + token_ids = update_token_ids(token_ids, next_token_ids, ancestors) + attention_masks = update_attention_masks(attention_masks, ancestors) + kv_cache = reorder_kv_cache(kv_cache, ancestors) + if len(ancestors) > 1: + fsms = reorder_fsms(fsms, ancestors) + fsm_states = reorder_fsm_states(fsm_states, ancestors) + + fsm_states = get_next_fsm_states(fsms, fsm_states, next_token_ids) + is_finished = is_generation_finished(fsms, fsm_states) + + if is_finished: + yield GenerationState( + token_ids, + kv_cache, + logits, + sequence_weights, + fsm_states, + ) + return + + yield GenerationState( + token_ids, + kv_cache, + logits, + sequence_weights, + fsm_states, + ) + + +def get_next_fsm_states( + fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor" +) -> List[int]: + """ + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + next_token_ids + The tokens that were just generated. + + Returns + ------- + A `torch.Tensor` object that represents the next logit mask. + + """ + return [ + fsm.get_next_state(fsm_state, int(token_id[0])) + for fsm, fsm_state, token_id in zip(fsms, fsm_states, next_token_ids) + ] + + +def get_allowed_tokens( + fsms: List["Guide"], fsm_states: List[int] +) -> List[Optional[Iterable[int]]]: + """Get the new instructions for each sequence from the finite-state machine. + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + fsm_states + The FSM states corresponding to each sequence in the batch. + + Returns + ------- + A nested list that contains the ids of the logits to keep. + + """ + return [ + fsm.get_next_instruction(state).tokens for fsm, state in zip(fsms, fsm_states) + ] + + +def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool: + """Determine if the generation is finished. + + A generation is considered finished if the FSM of every sequence in the + batch is in a final state. + + A better solution is to return finished sequences as soon as their FSM + is in a final state. + + Parameters + ---------- + fsm + The finite-state machine used to monitor this batch. + fsm_states + The FSM states corresponding to each sequence in the batch. + + Returns + ------- + Whether all sequences are finished sampling. + + """ + return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)]) + + +def update_token_ids( + token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": + """Append the sampled tokens to the running sequence of tokens. + + Parameters + ---------- + token_ids + The current token sequences + next_token_ids + The tokens that were just generated and that we need to append + to the existing sequences. + ancestors + The sequences to which the token ids need to be added. + + Returns + ------- + A new sequence of token ids that contains the tokens that were + just generated. + + """ + import torch + + token_ids = torch.index_select(token_ids, 0, ancestors) + return torch.concatenate([token_ids, next_token_ids], dim=-1) + + +def update_attention_masks( + attention_masks: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": + """Expand the attention masks. + + Parameters + ---------- + attention_masks + The attention masks for each sequence in the batch. 
+ ancestors + The sequences to which the token ids need to be added. + + Returns + ------- + The attention masks padded with 1s. + + """ + import torch + + attention_masks = torch.index_select(attention_masks, 0, ancestors) + return torch.concatenate( + [ + attention_masks, + torch.ones( + attention_masks.shape[:-1] + (1,), device=attention_masks.device + ), + ], + axis=-1, + ) + + +def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]: + reordered_fsms = [] + for ancestor in ancestors: + reordered_fsms.append(fsms[ancestor].copy()) + + return reordered_fsms + + +def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]: + reordered_states = [] + for ancestor in ancestors: + reordered_states.append(fsm_states[ancestor]) + + return reordered_states + + +def reorder_kv_cache( + kv_cache: Optional[Tuple], ancestors: "torch.Tensor" +) -> Optional[Tuple]: + """Re-order the KV-cache based on the ancestors. + + In transformers, the object that stores the KV-cache is a tuple who elements + are the key cache and the value cache. Each of these caches are tuples where + each element correpond to a layer. To each layer corresponds a tensor whose + first dimension is the batch size. + + """ + import torch + + if kv_cache is None: + return None + + new_kv_cache: Tuple = tuple() + for cache_item in kv_cache: + new_cache_item: Tuple = tuple() + for layer in cache_item: + layer = torch.index_select(layer, 0, ancestors.to(layer.device)) + new_cache_item += (layer,) + new_kv_cache += (new_cache_item,) + + return new_kv_cache + + +def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor": + """Mask the logits. + + The function iterates over a nested list where each list corresponds to the + indices that need to be masked for each row in the array. + + Parameters + ---------- + logits + Two dimensional tensor that contains the next-token probability + distribution. + allowed_token_ids + A list that contains the tokens that can be generated by the model. + + Returns + ------- + A view of the original logits tensor where some values are masked. + + """ + import torch + + biased_logits = torch.full_like(logits, -math.inf, device=logits.device) + for i, ids in enumerate(allowed_token_ids): + if ids is not None: + biased_logits[i, ids] = logits[i, ids] + else: + biased_logits[i] = logits[i] + return biased_logits diff --git a/build/lib/outlines/generate/json.py b/build/lib/outlines/generate/json.py new file mode 100644 index 000000000..3837f72b6 --- /dev/null +++ b/build/lib/outlines/generate/json.py @@ -0,0 +1,78 @@ +import json as pyjson +from functools import singledispatch +from typing import Callable, Optional, Union + +from pydantic import BaseModel + +from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature +from outlines.generate.api import SequenceGenerator +from outlines.models import OpenAI +from outlines.samplers import Sampler, multinomial + +from .regex import regex + + +@singledispatch +def json( + model, + schema_object: Union[str, object, Callable], + sampler: Sampler = multinomial(), + whitespace_pattern: Optional[str] = None, +) -> SequenceGenerator: + """ + Generate structured JSON data with a `Transformer` model based on a specified JSON Schema. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + schema_object: + The JSON Schema to generate data for. 
Can be a JSON string, a Pydantic model, or a callable + that returns a JSON schema. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the schema_object and + transforms the result if BaseModel is used. + + """ + if isinstance(schema_object, type(BaseModel)): + schema = pyjson.dumps(schema_object.model_json_schema()) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: schema_object.parse_raw(x) + elif callable(schema_object): + schema = pyjson.dumps(get_schema_from_signature(schema_object)) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) + elif isinstance(schema_object, str): + schema = schema_object + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) + else: + raise ValueError( + f"Cannot parse schema {schema_object}. The schema must be either " + + "a Pydantic object, a function or a string that contains the JSON " + + "Schema specification" + ) + + return generator + + +@json.register(OpenAI) +def json_openai( + model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial() +): + raise NotImplementedError( + "Cannot use JSON Schema-structure generation with an OpenAI model " + + "due to the limitations of the OpenAI API" + ) diff --git a/build/lib/outlines/generate/regex.py b/build/lib/outlines/generate/regex.py new file mode 100644 index 000000000..ceea5d994 --- /dev/null +++ b/build/lib/outlines/generate/regex.py @@ -0,0 +1,73 @@ +from functools import singledispatch + +from outlines.fsm.guide import RegexGuide +from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter +from outlines.models import OpenAI +from outlines.models.llamacpp import LlamaCpp +from outlines.models.vllm import VLLM +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def regex(model, regex_str: str, sampler: Sampler = multinomial()): + """Generate structured text in the language of a regular expression. + + Parameters + ---------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + regex_str: + The regular expression that the output must follow. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text constrained by the + regular expression. 
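A usage sketch for the `json` entry point added above (illustrative only; the model name "gpt2" is an assumption, any `transformers` checkpoint works the same way):

    from pydantic import BaseModel

    import outlines

    class Character(BaseModel):
        name: str
        age: int

    model = outlines.models.transformers("gpt2")             # assumed model name
    generator = outlines.generate.json(model, Character)     # schema-constrained generation
    character = generator("Describe a character in JSON: ")  # returns a Character instance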
+ + """ + fsm = RegexGuide(regex_str, model.tokenizer) + + device = model.device + generator = SequenceGenerator(fsm, model, sampler, device) + + return generator + + +@regex.register(LlamaCpp) +def regex_llamacpp( + model: LlamaCpp, + regex_str: str, + sampler: Sampler = multinomial(), +): + from outlines.integrations.llamacpp import RegexLogitsProcessor + + logits_processor = RegexLogitsProcessor(regex_str, llm=model.model) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@regex.register(VLLM) +def regex_vllm( + model: VLLM, + regex_str: str, + sampler: Sampler = multinomial(), +): + from outlines.integrations.vllm import RegexLogitsProcessor + + logits_processor = RegexLogitsProcessor(regex_str, model.model) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + +@regex.register(OpenAI) +def regex_openai( + model: OpenAI, + regex_str: str, + sampler: Sampler = multinomial(), +): + raise NotImplementedError( + "Cannot use regex-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." + ) diff --git a/build/lib/outlines/generate/text.py b/build/lib/outlines/generate/text.py new file mode 100644 index 000000000..35031348d --- /dev/null +++ b/build/lib/outlines/generate/text.py @@ -0,0 +1,57 @@ +from functools import singledispatch + +from outlines.fsm.guide import StopAtEOSGuide +from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter +from outlines.models import VLLM, LlamaCpp, OpenAI +from outlines.samplers import Sampler, multinomial + + +@singledispatch +def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: + """Generate text with a `Transformer` model. + + Note + ---- + Python 3.11 allows dispatching on Union types and + this should greatly simplify the code. + + Arguments + --------- + model: + An instance of `Transformer` that represents a model from the + `transformers` library. + sampler: + The sampling algorithm to use to generate token ids from the logits + distribution. + + Returns + ------- + A `SequenceGenerator` instance that generates text. + + """ + fsm = StopAtEOSGuide(model.tokenizer) + device = model.device + generator = SequenceGenerator(fsm, model, sampler, device) + + return generator + + +@text.register(VLLM) +def text_vllm(model: VLLM, sampler: Sampler = multinomial()): + return SequenceGeneratorAdapter(model, None, sampler) + + +@text.register(LlamaCpp) +def text_llamacpp(model: LlamaCpp, sampler: Sampler = multinomial()): + return SequenceGeneratorAdapter(model, None, sampler) + + +@text.register(OpenAI) +def text_openai(model: OpenAI, sampler: Sampler = multinomial()) -> OpenAI: + if not isinstance(sampler, multinomial): + raise NotImplementedError( + r"The OpenAI API does not support any other sampling algorithm " + + "than the multinomial sampler." 
+ ) + + return model diff --git a/build/lib/outlines/grammars.py b/build/lib/outlines/grammars.py new file mode 100644 index 000000000..f0c122964 --- /dev/null +++ b/build/lib/outlines/grammars.py @@ -0,0 +1,14 @@ +from pathlib import Path + +GRAMMAR_PATH = Path(__file__).parent / "grammars" + + +def read_grammar(grammar_file_name, base_grammar_path=GRAMMAR_PATH): + """Read grammar file from default grammar path""" + full_path = base_grammar_path / grammar_file_name + with open(full_path) as file: + return file.read() + + +arithmetic = read_grammar("arithmetic.lark") +json = read_grammar("json.lark") diff --git a/build/lib/outlines/grammars/arithmetic.lark b/build/lib/outlines/grammars/arithmetic.lark new file mode 100644 index 000000000..2332650c6 --- /dev/null +++ b/build/lib/outlines/grammars/arithmetic.lark @@ -0,0 +1,18 @@ +?start: sum + +?sum: product +| sum "+" product -> add +| sum "-" product -> sub + +?product: atom +| product "*" atom -> mul +| product "/" atom -> div + +?atom: NUMBER -> number +| "-" atom -> neg +| "(" sum ")" + +%import common.NUMBER +%import common.WS_INLINE + +%ignore WS_INLINE diff --git a/build/lib/outlines/grammars/common.lark b/build/lib/outlines/grammars/common.lark new file mode 100644 index 000000000..801c27e97 --- /dev/null +++ b/build/lib/outlines/grammars/common.lark @@ -0,0 +1,80 @@ +// Adapted from https://github.com/lark-parser/lark/blob/master/lark/grammars/common.lark + +// Lark License: +// Copyright © 2017 Erez Shinan +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +// Basic terminals for common use + + +// +// Numbers +// + +DIGIT: "0".."9" +HEXDIGIT: "a".."f"|"A".."F"|DIGIT + +INT: DIGIT+ +SIGNED_INT: ["+"|"-"] INT +DECIMAL: INT "." INT? | "." INT + +// float = /-?\d+(\.\d+)?([eE][+-]?\d+)?/ +_EXP: ("e"|"E") SIGNED_INT +FLOAT: INT _EXP | DECIMAL _EXP? +SIGNED_FLOAT: ["+"|"-"] FLOAT + +NUMBER: FLOAT | INT +SIGNED_NUMBER: ["+"|"-"] NUMBER + +// +// TODO: Working escaped_string +// +UNESCAPED_STRING: /\"[^"]*\"/ + + + +// +// Names (Variables) +// +LCASE_LETTER: "a".."z" +UCASE_LETTER: "A".."Z" + +LETTER: UCASE_LETTER | LCASE_LETTER +WORD: LETTER+ + +CNAME: ("_"|LETTER) ("_"|LETTER|DIGIT)* + + +// +// Whitespace +// +WS_INLINE: (" "|/\t/)+ +WS: /[ \t\f\r\n]/+ + +CR : /\r/ +LF : /\n/ +NEWLINE: (CR? 
LF)+ + + +// Comments +SH_COMMENT: /#[^\n]*/ +CPP_COMMENT: /\/\/[^\n]*/ +C_COMMENT: "/*" /(.|\n)*?/ "*/" +SQL_COMMENT: /--[^\n]*/ diff --git a/build/lib/outlines/grammars/json.lark b/build/lib/outlines/grammars/json.lark new file mode 100644 index 000000000..72af448ce --- /dev/null +++ b/build/lib/outlines/grammars/json.lark @@ -0,0 +1,19 @@ +?start: value + +?value: object +| array +| UNESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : UNESCAPED_STRING ":" value + +%import common.UNESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS + +%ignore WS diff --git a/build/lib/outlines/integrations/__init__.py b/build/lib/outlines/integrations/__init__.py new file mode 100644 index 000000000..b0a90d5ea --- /dev/null +++ b/build/lib/outlines/integrations/__init__.py @@ -0,0 +1 @@ +"""Utility functions and classes used to integrate `outlines` into other packages.""" diff --git a/build/lib/outlines/integrations/llamacpp.py b/build/lib/outlines/integrations/llamacpp.py new file mode 100644 index 000000000..8e18c33e7 --- /dev/null +++ b/build/lib/outlines/integrations/llamacpp.py @@ -0,0 +1,191 @@ +"""Make LlamaCpp compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from typing import TYPE_CHECKING, Optional, Type, Union + +import numpy as np +import torch +from numpy.typing import NDArray +from pydantic import BaseModel + +from outlines.fsm.guide import CFGGuide, Guide, RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import convert_json_schema_to_str +from outlines.models.llamacpp import LlamaCppTokenizer + +if TYPE_CHECKING: + from llama_cpp import Llama + + +class LogitsProcessor: + """Bias LlamaCpp generation using a finite state machine. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, tokenizer: LlamaCppTokenizer, fsm: Guide): + """A FSM-based logits processor. + + Parameters + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + self.tokenizer = tokenizer + self._fsm_state = 0 + self.fsm: Guide = fsm + self._is_first_token = True + + def __call__( + self, input_ids: NDArray[np.int64], scores: NDArray[np.float32] + ) -> NDArray[np.float32]: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The input token ids. + scores + The logits. + + Returns + ------- + NDArray[np.float32] + The biased logits. 
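A self-contained sketch (not part of the diff) of the masking applied just below, using toy logits: disallowed token scores are pushed to -inf while allowed ones are left untouched.

    import math

    import numpy as np

    scores = np.array([1.5, 0.2, -0.7, 2.1], dtype=np.float32)  # toy logits, vocab of 4
    allowed_tokens = [0, 3]                                      # as returned by the FSM

    mask = np.full(scores.shape[-1], -math.inf, dtype=np.float32)
    mask[allowed_tokens] = 0
    print(scores + mask)   # [1.5, -inf, -inf, 2.1]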
+ """ + if self._is_first_token: + self._is_first_token = False + else: + last_token = input_ids[-1] + self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) + + allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens + + mask = torch.full((scores.shape[-1],), -math.inf, device="cpu").numpy() + mask[allowed_tokens] = 0 + biased_scores = scores + mask + + return biased_scores + + def copy(self) -> "LogitsProcessor": + """Return a copy of the logits processor.""" + return LogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy()) + + +class RegexLogitsProcessor(LogitsProcessor): + """Bias LlamaCpp generation based on a regular expression. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, regex_string: str, llm: "Llama"): + """Compile the FSM that drives the regex-guided generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + The Llama model. + """ + tokenizer = LlamaCppTokenizer(model=llm) + fsm = RegexGuide(regex_string, tokenizer) + super().__init__(tokenizer=tokenizer, fsm=fsm) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias LlamaCpp generation based on a JSON schema. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + llm: "Llama", + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + llm + The Llama model. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, llm=llm) + + +class CFGLogitsProcessor(LogitsProcessor): + """Bias LlamaCpp generation based on a context-free grammar. + + Attributes + ---------- + llm + The Llama model. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, cfg_str: str, llm: "Llama"): + """Compile the FSM that drives the CFG-guided generation. + + Parameters + ---------- + cfg_str + A string that represents a grammar + llm + The Llama model. + """ + tokenizer = LlamaCppTokenizer(model=llm) + fsm = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) + super().__init__(tokenizer=tokenizer, fsm=fsm) diff --git a/build/lib/outlines/integrations/transformers.py b/build/lib/outlines/integrations/transformers.py new file mode 100644 index 000000000..7c1bafd22 --- /dev/null +++ b/build/lib/outlines/integrations/transformers.py @@ -0,0 +1,159 @@ +"""Make Hugging Face transformers compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict +from typing import DefaultDict, Iterable, Optional, Type, Union + +import torch +from pydantic import BaseModel +from transformers import Pipeline, PreTrainedTokenizerBase + +from outlines.fsm.guide import RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str + + +class RegexPrefixAllowedTokens: + """Bias transformers generation based on a regular expression. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + regex_string: str, + tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], + ): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression. + tokenizer_or_pipe + The tokenizer of the model, or the pipeline object. + + Raises + ------ + ValueError + If the `tokenizer_or_pipe` parameter is not a tokenizer or a pipeline. + """ + if isinstance(tokenizer_or_pipe, Pipeline): + tokenizer = tokenizer_or_pipe.tokenizer + elif isinstance(tokenizer_or_pipe, PreTrainedTokenizerBase): + tokenizer = tokenizer_or_pipe + else: + raise ValueError( + "The tokenizer_or_pipe parameter must be a tokenizer or a pipeline." + ) + assert isinstance(tokenizer, PreTrainedTokenizerBase) + tokenizer = adapt_tokenizer(tokenizer=tokenizer) + self.fsm = RegexGuide(regex_string=regex_string, tokenizer=tokenizer) + self._fsm_state: DefaultDict[int, int] = defaultdict(int) + + # The generated text with `transformers` include the input token IDs as well, + # so we use this attribute to keep track of the input token IDs. This allows us + # to reset the FSM state when the input token IDs change, as well as to only + # apply the FSM to the generated tokens. + self._prefix = [-1] + + def __call__(self, batch_id: int, sent: torch.Tensor) -> Optional[Iterable[int]]: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + batch_id + The index of the current batch. + sent + The tokens of the current sentence. + + Returns + ------- + List[int] + The indices of the tokens that are allowed to be sampled next. 
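A usage sketch for `RegexPrefixAllowedTokens` (illustrative only; the model name and the date regex are assumptions): the object plugs into `transformers` generation through the `prefix_allowed_tokens_fn` argument.

    from transformers import pipeline

    from outlines.integrations.transformers import RegexPrefixAllowedTokens

    pipe = pipeline("text-generation", model="gpt2")            # assumed model
    prefix_fn = RegexPrefixAllowedTokens(r"\d{4}-\d{2}-\d{2}", pipe)

    out = pipe(
        "What is today's date? ",
        prefix_allowed_tokens_fn=prefix_fn,
        max_new_tokens=12,
    )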
+ """ + input_ids = sent.tolist() + + # If the prefix token IDs have changed we assume that we are dealing with a new + # sample and reset the FSM state + if input_ids[: len(self._prefix)] != self._prefix: + self._fsm_state = defaultdict(int) + self._prefix = input_ids + seq_id = hash(tuple([])) + + else: + # Remove the prefix token IDs from the input token IDs, as the FSM should + # only be applied to the generated tokens + input_ids = input_ids[len(self._prefix) :] + + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + seq_id = hash(tuple(input_ids)) + self._fsm_state[seq_id] = self.fsm.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token + ) + + allowed_tokens = self.fsm.get_next_instruction( + state=self._fsm_state[seq_id] + ).tokens + return allowed_tokens + + +class JSONPrefixAllowedTokens(RegexPrefixAllowedTokens): + """Bias transformers generation based on a JSON schema. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + tokenizer_or_pipe: Union[PreTrainedTokenizerBase, Pipeline], + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A schema that encodes the structure we want the model to generate. + tokenizer_or_pipe + The tokenizer of the model, or the pipeline object. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, tokenizer_or_pipe=tokenizer_or_pipe) diff --git a/build/lib/outlines/integrations/utils.py b/build/lib/outlines/integrations/utils.py new file mode 100644 index 000000000..9ac4e2a4f --- /dev/null +++ b/build/lib/outlines/integrations/utils.py @@ -0,0 +1,103 @@ +"""Utility functions used in integrations with other packages. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +from typing import Type, Union + +from pydantic import BaseModel +from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase + + +def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: + """Adapt a tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of `transformers`. In + addition we need to handle the missing spaces to Llama's tokenizer to be able to + compile FSMs for this model. + + Parameters + ---------- + tokenizer + The tokenizer of the model. 
+ + Returns + ------- + PreTrainedTokenizerBase + The adapted tokenizer. + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: Union[str, bytes]) -> str: + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if ( + type(token) is str + and token.startswith(SPIECE_UNDERLINE) + or token == "<0x20>" + ): + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer + + +def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: + """Convert a JSON schema to a string. + + Parameters + ---------- + json_schema + The JSON schema. + + Returns + ------- + str + The JSON schema converted to a string. + + Raises + ------ + ValueError + If the schema is not a dictionary, a string or a Pydantic class. + """ + if isinstance(json_schema, dict): + schema_str = json.dumps(json_schema) + elif isinstance(json_schema, str): + schema_str = json_schema + elif issubclass(json_schema, BaseModel): + schema_str = json.dumps(json_schema.model_json_schema()) + else: + raise ValueError( + f"Cannot parse schema {json_schema}. The schema must be either " + + "a Pydantic class, a dictionary or a string that contains the JSON " + + "schema specification" + ) + return schema_str diff --git a/build/lib/outlines/integrations/vllm.py b/build/lib/outlines/integrations/vllm.py new file mode 100644 index 000000000..8421be4bf --- /dev/null +++ b/build/lib/outlines/integrations/vllm.py @@ -0,0 +1,177 @@ +"""Make vLLM compatible with Outlines' structured generation. + + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math +from collections import defaultdict +from typing import TYPE_CHECKING, DefaultDict, List, Optional, Type, Union +import torch +from pydantic import BaseModel + +from outlines.fsm.guide import RegexGuide, Write, Generate +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str +import logging + +from tqdm import tqdm + +if TYPE_CHECKING: + from vllm import LLM +logger = logging.getLogger(__name__) +logging.basicConfig(filename='/pvc/outlines.log', encoding='utf-8', level=logging.DEBUG) + + +class RegexLogitsProcessor: + """Bias vLLM generation based on a regular expression. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, regex_string: str, llm: "LLM"): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression. + llm + The vLLM model. 
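An illustrative check (not part of the diff) of the `convert_json_schema_to_str` helper added above, which accepts a dict, a JSON string, or a Pydantic class:

    from pydantic import BaseModel

    from outlines.integrations.utils import convert_json_schema_to_str

    class User(BaseModel):
        name: str

    print(convert_json_schema_to_str({"type": "string"}))   # '{"type": "string"}'
    print(convert_json_schema_to_str(User))                 # JSON dump of User.model_json_schema()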
+ + Raises + ------ + ValueError + If the provided LLM instance in `RegexLogitsProcessor` neither has a + `tokenizer` attribute or a `get_tokenizer` method. + """ + if hasattr(llm, "get_tokenizer"): + tokenizer = llm.get_tokenizer() + elif hasattr(llm, "tokenizer"): + if hasattr(llm.tokenizer, "tokenizer"): + tokenizer = llm.tokenizer.tokenizer + else: + tokenizer = llm.tokenizer + else: + raise ValueError( + "The provided LLM instance in `RegexLogitsProcessor` neither has a " + "`tokenizer` attribute or a `get_tokenizer` method." + ) + tokenizer = adapt_tokenizer(tokenizer=tokenizer) + self.fsm = RegexGuide(regex_string, tokenizer) + self._fsm_state: DefaultDict[int, int] = defaultdict(int) + self.mask_cache = None + + def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The tokens of the current sentence. + scores + The logits of the current sentence. + + Returns + ------- + torch.Tensor + The biased logits. + """ + + if self.mask_cache is None: + self.mask_cache = torch.full((max(self.fsm.states_to_token_maps.keys()) + 2, + scores.shape[-1]), -math.inf, + dtype=scores.dtype) + # fill mask_cache (keys should be from 0 to num_states) + for key in tqdm(self.fsm.states_to_token_maps.keys(), desc="Init mask cache"): + allowed_tokens = self.fsm.get_next_instruction(key).tokens + self.mask_cache[key][allowed_tokens] = 0 + self.mask_cache = self.mask_cache.to(scores.device, non_blocking=True) + + seq_id = hash(tuple(input_ids)) + + # Initialize the FSM state dictionary if the input_ids are empty, as this means + # that the input_ids are the first tokens of the sequence. + if len(input_ids) > 0: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self._fsm_state[seq_id] = self.fsm.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token + ) + + state = self._fsm_state[seq_id] + mask = self.mask_cache[state] + + ''' + if state in self.mask_cache: + mask = self.mask_cache[state] + else: + allowed_tokens = self.fsm.get_next_instruction( + state=self._fsm_state[seq_id] + ).tokens + mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask[allowed_tokens] = 0 + self.mask_cache[state] = mask + logger.debug("state: " + str(state)) + ''' + + biased_scores = scores + mask + + return biased_scores + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias vLLM generation based on a JSON schema. + + Attributes + ---------- + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + llm: "LLM", + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + llm + The vLLM model. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). 
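The mask cache introduced in `RegexLogitsProcessor.__call__` above can be sketched in isolation (illustrative only; the toy FSM transition map is an assumption): one mask row per FSM state is built once, then looked up by state at every decoding step.

    import math

    import torch

    # toy stand-in for fsm.states_to_token_maps: FSM state -> {allowed token id: next state}
    states_to_token_maps = {0: {5: 1, 7: 1}, 1: {2: 2}}
    vocab_size = 10

    mask_cache = torch.full((max(states_to_token_maps) + 2, vocab_size), -math.inf)
    for state, transitions in states_to_token_maps.items():
        mask_cache[state, list(transitions)] = 0      # allowed tokens keep their logits

    scores = torch.randn(vocab_size)
    biased = scores + mask_cache[0]                   # state 0: only tokens 5 and 7 stay finite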
For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, llm=llm) diff --git a/build/lib/outlines/models/__init__.py b/build/lib/outlines/models/__init__.py new file mode 100644 index 000000000..3676e6ccc --- /dev/null +++ b/build/lib/outlines/models/__init__.py @@ -0,0 +1,17 @@ +"""Module that contains all the models integrated in outlines. + +We group the models in submodules by provider instead of theme (completion, chat +completion, diffusers, etc.) and use routing functions everywhere else in the +codebase. + +""" +from typing import Union + +from .exllamav2 import ExLlamaV2Model, exl2 +from .llamacpp import LlamaCpp, llamacpp +from .mamba import Mamba, mamba +from .openai import OpenAI, azure_openai, openai +from .transformers import Transformers, transformers +from .vllm import VLLM, vllm + +LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, Mamba] diff --git a/build/lib/outlines/models/exllamav2.py b/build/lib/outlines/models/exllamav2.py new file mode 100644 index 000000000..0ec6ef033 --- /dev/null +++ b/build/lib/outlines/models/exllamav2.py @@ -0,0 +1,232 @@ +import os +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora + from transformers import PreTrainedTokenizer + import torch + +from .transformers import TransformerTokenizer + + +class ExLlamaV2Model: + """Represents a `exl2` model.""" + + def __init__( + self, + model: "ExLlamaV2", + tokenizer: "PreTrainedTokenizer", + device, + cache: "ExLlamaV2Cache", + lora: Optional["ExLlamaV2Lora"] = None, + ): + self.device = device + self.model = model + self.tokenizer = TransformerTokenizer(tokenizer) + self.cache = cache + self.past_seq = None + self.lora = lora + + def forward(self, input_ids: "torch.LongTensor", *_): + """Compute a forward pass through the exl2 model.""" + import torch + + # Caching with past_seq + reset = True + seq_tensor = input_ids[0] + + if self.past_seq is not None: + min_length = min(self.past_seq.shape[0], seq_tensor.shape[0]) + indices = torch.nonzero( + ~torch.eq(self.past_seq[:min_length], seq_tensor[:min_length]) + ) + if len(indices) > 0: + longest_prefix = indices[0].item() + else: + longest_prefix = min_length + + if longest_prefix > 0: + reset = False + self.cache.current_seq_len = longest_prefix + if seq_tensor.shape[0] - longest_prefix > 1: + self.model.forward( + seq_tensor[longest_prefix:-1].view(1, -1), + self.cache, + preprocess_only=True, + loras=[self.lora], + ) + elif seq_tensor.shape[0] == longest_prefix: + self.cache.current_seq_len -= 1 + + if reset: + self.cache.current_seq_len = 0 + if seq_tensor.shape[0] > 1: + self.model.forward( + seq_tensor[:-1].view(1, -1), + self.cache, + preprocess_only=True, + loras=[self.lora], + ) + + self.past_seq = seq_tensor + + return self.model.forward( + seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora] + ) + + def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor": + logits = self.forward(input_ids) + next_token_logits = logits[..., -1, :] + + return next_token_logits, None + + def update_lora(self, lora_path: Optional[str] = None): + """ + Update and apply the LoRA to the model. + + Args: + lora_path (Optional[str]): The path to the LoRA directory. If None, the LoRA will be unloaded. 
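A small standalone sketch (not part of the diff) of the longest-common-prefix test used in `ExLlamaV2Model.forward` above to decide how much of the sequence cache can be reused:

    import torch

    past_seq = torch.tensor([1, 2, 3, 4])
    seq_tensor = torch.tensor([1, 2, 9, 4, 5])

    min_length = min(past_seq.shape[0], seq_tensor.shape[0])
    indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
    longest_prefix = indices[0].item() if len(indices) > 0 else min_length
    print(longest_prefix)   # 2: only the first two cached tokens can be reused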
+ """ + try: + from exllamav2 import ExLlamaV2Lora + except ImportError: + raise ImportError( + "The `exllamav2` library needs to be installed in order to use `exllamav2` models." + ) + if lora_path is None: + if self.lora is not None: + print(" -- Unloading LoRA...") + self.lora = None + else: + self.lora = ExLlamaV2Lora.from_directory(self.model, lora_path) + print(" -- Loading LoRA...") + + +def exl2( + model_path: str, + device: str, + max_seq_len: Optional[int] = None, + scale_pos_emb: Optional[float] = None, + scale_alpha_value: Optional[float] = None, + no_flash_attn: Optional[bool] = None, + num_experts_per_token: Optional[int] = None, + cache_8bit: bool = False, + cache_q4: bool = False, + tokenizer_kwargs: dict = {}, + gpu_split: Optional[str] = None, + low_mem: Optional[bool] = None, + verbose: Optional[bool] = None, +) -> ExLlamaV2Model: + """ + Load an ExLlamaV2 model. + + Parameters + ---------- + model_path (str) + Path to the model directory. + device (str) + Device to load the model on. Pass in 'cuda' for GPU or 'cpu' for CPU + max_seq_len (Optional[int], optional) + Maximum sequence length. Defaults to None. + scale_pos_emb (Optional[float], optional) + Scale factor for positional embeddings. Defaults to None. + scale_alpha_value (Optional[float], optional) + Scale alpha value. Defaults to None. + no_flash_attn (Optional[bool], optional) + Disable flash attention. Defaults to None. + num_experts_per_token (Optional[int], optional) + Number of experts per token. Defaults to None. + cache_8bit (bool, optional) + Use 8-bit cache. Defaults to False. + cache_q4 (bool, optional) + Use Q4 cache. Defaults to False. + tokenizer_kwargs (dict, optional) + Additional keyword arguments for the tokenizer. Defaults to {}. + gpu_split (str) + \"auto\", or VRAM allocation per GPU in GB. Auto will use exllama's autosplit feature + low_mem (bool, optional) + Enable VRAM optimizations, potentially trading off speed + verbose (bool, optional) + Enable if you want debugging statements + + Returns + ------- + An `ExLlamaV2Model` instance. + + Raises + ------ + `ImportError` if the `exllamav2` library is not installed. + + """ + try: + from exllamav2 import ( + ExLlamaV2, + ExLlamaV2Cache, + ExLlamaV2Cache_8bit, + ExLlamaV2Cache_Q4, + ExLlamaV2Config, + ) + from transformers import AutoTokenizer + except ImportError: + raise ImportError( + "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models." 
+ ) + + # Load tokenizer + if not verbose: + print(" -- Loading tokenizer...") + tokenizer_kwargs.setdefault("padding_side", "left") + tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) + # tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs) + + # Check fasttensors for config + if os.name != "nt": + use_fasttensors = True + else: + use_fasttensors = False + + # Create config + config = ExLlamaV2Config() + config.model_dir = model_path + config.fasttensors = use_fasttensors + config.prepare() + + # Set config options + if max_seq_len is not None: + config.max_seq_len = max_seq_len + if scale_pos_emb is not None: + config.scale_pos_emb = scale_pos_emb + if scale_alpha_value is not None: + config.scale_alpha_value = scale_alpha_value + if no_flash_attn is not None: + config.no_flash_attn = no_flash_attn + if num_experts_per_token is not None: + config.num_experts_per_token = num_experts_per_token + if low_mem: + config.set_low_mem() + + # Prepare the model from the config + model = ExLlamaV2(config) + + # Create cache + if cache_8bit: + cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded) + elif cache_q4: + cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded) + else: + cache = ExLlamaV2Cache(model, lazy=not model.loaded) + + # Load the model + split = None + if gpu_split and gpu_split != "auto": + split = [float(alloc) for alloc in gpu_split.split(",")] + if not verbose: + print(" -- Loading model...") + model.load(split) + + # Autoload if no GPU split was provided + if not model.loaded: + print(" -- Loading model...") + model.load_autosplit(cache) + + return ExLlamaV2Model(model, tokenizer, device, cache) diff --git a/build/lib/outlines/models/llamacpp.py b/build/lib/outlines/models/llamacpp.py new file mode 100644 index 000000000..840e1364f --- /dev/null +++ b/build/lib/outlines/models/llamacpp.py @@ -0,0 +1,391 @@ +import dataclasses +import pickle +import warnings +from typing import ( + TYPE_CHECKING, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + TypedDict, + Union, +) + +from typing_extensions import Unpack + +from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.tokenizer import Tokenizer + +if TYPE_CHECKING: + from llama_cpp import Llama, LogitsProcessorList + + +class LlamaCppTokenizer(Tokenizer): + def __init__(self, model: "Llama"): + self.eos_token_id = model.token_eos() + self.eos_token = model.tokenizer().decode([self.eos_token_id]) + self.pad_token_id = self.eos_token_id + self.special_tokens: Set[int] = set() + + self.vocabulary: Dict[str, int] = dict() + + self.tokenizer = model.tokenizer() + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + try: + self.vocabulary = model.tokenizer_.hf_tokenizer.get_vocab() + except AttributeError: + # ### + for t in range(model.n_vocab()): + token_piece = model.tokenizer().decode([t]) + self.vocabulary[token_piece] = t + + # ensure stable ordering of vocabulary + self.vocabulary = { + tok: tok_id + for tok, tok_id in sorted(self.vocabulary.items(), key=lambda x: x[1]) + } + + self._hash = None + + def decode(self, token_ids: List[int]) -> List[str]: + decoded_bytes = self.tokenizer.detokenize(token_ids) + return [decoded_bytes.decode("utf-8", errors="ignore")] + + def encode( + self, prompt: Union[str, List[str]], add_bos: bool = True, special: bool = True + ) -> Tuple[List[int], List[int]]: + if isinstance(prompt, list): + raise NotImplementedError( + "llama-cpp-python tokenizer doesn't support batch 
tokenization" + ) + token_ids = self.tokenizer.tokenize( + prompt.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special + ) + # generate attention mask, missing from llama-cpp-python + attention_mask = [ + 1 if token_id != self.pad_token_id else 0 for token_id in token_ids + ] + return token_ids, attention_mask + + def convert_token_to_string(self, token: str) -> str: + return token + + def __eq__(self, other): + if not isinstance(other, LlamaCppTokenizer): + return False + return self.__getstate__() == other.__getstate__() + + def __hash__(self): + if self._hash is None: + self._hash = hash(pickle.dumps(self)) + return self._hash + + def __getstate__(self): + """Create a stable representation for outlines.caching""" + return ( + self.vocabulary, + self.eos_token_id, + self.eos_token, + self.pad_token_id, + sorted(self.special_tokens), + ) + + def __setstate__(self, state): + raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") + + +class LlamaCppParams(TypedDict, total=False): + suffix: Optional[str] + temperature: float + top_p: float + min_p: float + typical_p: float + seed: int + max_tokens: int + logits_processor: "LogitsProcessorList" + stop: Optional[Union[str, List[str]]] + frequence_penalty: float + presence_penalty: float + repeat_penalty: float + top_k: int + tfs_z: float + mirostat_mode: int + mirostat_tau: float + mirostat_eta: float + stream: bool + + +class LlamaCpp: + """Represents a model provided by the `llama-cpp-python` library. + + We wrap models from model providing libraries in order to give all of + them the same interface in Outlines and allow users to easily switch + between providers. This class wraps the `llama_cpp.Llama` class from the + `llama-cpp-python` library. + + """ + + def __init__(self, model: "Llama"): + self.model = model + + def prepare_generation_parameters( + self, + generation_parameters: GenerationParameters, + sampling_parameters: SamplingParameters, + structure_logits_processor, + **llama_cpp_params: Unpack[LlamaCppParams], + ): + """Prepare the generation parameters. + + `llama-cpp-python` uses different default values + + """ + from llama_cpp import LogitsProcessorList + + max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + + # We update `llama_cpp_params` with the values the user passed to the + # generator. + if "stop" not in llama_cpp_params: + llama_cpp_params["stop"] = stop_at + if "seed" not in llama_cpp_params: + llama_cpp_params["seed"] = seed + + # Somehow `llama-cpp-python` generates `max_tokens + 1` tokens + if "max_tokens" not in llama_cpp_params: + if max_tokens is not None: + llama_cpp_params["max_tokens"] = max_tokens - 1 + else: + llama_cpp_params["max_tokens"] = llama_cpp_params["max_tokens"] - 1 + + sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( + sampling_parameters + ) + + # We update the `llama_cpp_params` with the sampling values that + # were specified by the user via the `Sampler` class, unless they + # are also specified in `llama_cpp_params`. We also disable other + # sampling methods that are enabled by default and reset the temperature + # value. + # + # See https://github.com/ggerganov/llama.cpp/blob/e11a8999b5690f810c2c99c14347f0834e68c524/common/sampling.h#L22 + # for the default values in `llama.cpp` and indications to disable the sampling modes. + # Mirostat sampling, tail-free sampling and all penalties are disabled by default. 
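As a hedged illustration of the precedence spelled out in the comments above, keyword arguments forwarded to `llama-cpp-python` supersede the values coming from the Outlines sampler. This assumes the high-level generator forwards extra keyword arguments as `llama_cpp_params`; the repository and file names are placeholders:

import outlines

model = outlines.models.llamacpp(
    "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",   # placeholder repo
    "mistral-7b-instruct-v0.2.Q4_K_M.gguf",     # placeholder file
)
generator = outlines.generate.text(
    model, sampler=outlines.samplers.multinomial(top_k=40)
)

# `top_p` passed here takes precedence over the sampler's settings, while
# `top_k=40` from the sampler is still forwarded because it is not repeated.
answer = generator("Name a prime number: ", max_tokens=16, top_p=0.9)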
+ # + # See https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ + # for default values in `llama-cpp-python` + if sampler == "beam_search": + raise NotImplementedError( + "The `llama_cpp_python` library does not support Beam Search." + ) + if num_samples != 1: + raise NotImplementedError( + "The `llama_cpp_python` library does not allow to take several samples." + ) + if "top_p" not in llama_cpp_params: + if top_p is not None: + llama_cpp_params["top_p"] = top_p + else: + llama_cpp_params["top_p"] = 1.0 + + if "min_p" not in llama_cpp_params: + llama_cpp_params["min_p"] = 0.0 + + if "top_k" not in llama_cpp_params: + if top_k is not None: + llama_cpp_params["top_k"] = top_k + else: + llama_cpp_params["top_k"] = -1 + + if "temperature" not in llama_cpp_params: + if temperature is not None: + llama_cpp_params["temperature"] = temperature + else: + llama_cpp_params["temperature"] = 1.0 + + if "repeat_penalty" not in llama_cpp_params: + llama_cpp_params["repeat_penalty"] = 1.0 + + # The choice to stream or not should happen via the high-level API + llama_cpp_params["stream"] = False + + if structure_logits_processor is not None: + if "logits_processor" in llama_cpp_params: + llama_cpp_params["logits_processor"].append(structure_logits_processor) + else: + llama_cpp_params["logits_processor"] = LogitsProcessorList( + [structure_logits_processor] + ) + + return llama_cpp_params + + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **llama_cpp_params: Unpack[LlamaCppParams], + ) -> str: + """Generate text using `llama-cpp-python`. + + Arguments + --------- + prompts + A prompt or list of prompts. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + llama_cpp_params + Keyword arguments that can be passed to + `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params` + supersede the values of the parameters in `generation_parameters` and + `sampling_parameters`. See the `llama_cpp_python` documentation for + a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ + + Returns + ------- + The generated text. + + """ + if not isinstance(prompts, str): + raise NotImplementedError( + "The `llama-cpp-python` library does not support batch inference." + ) + + llama_cpp_params = self.prepare_generation_parameters( + generation_parameters, + sampling_parameters, + structure_logits_processor, + **llama_cpp_params, + ) + completion = self.model(prompts, **llama_cpp_params) + result = completion["choices"][0]["text"] + + self.model.reset() + + return result + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **llama_cpp_params: Unpack[LlamaCppParams], + ) -> Iterator[str]: + """Stream text using `llama-cpp-python`. + + Arguments + --------- + prompts + A prompt or list of prompts. 
+ generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + llama_cpp_params + Keyword arguments that can be passed to + `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params` + supersede the values of the parameters in `generation_parameters` and + `sampling_parameters`. See the `llama_cpp_python` documentation for + a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ + + Returns + ------- + A generator that return strings. + + """ + + if not isinstance(prompts, str): + raise NotImplementedError( + "The `llama-cpp-python` library does not support batch inference." + ) + + llama_cpp_params = self.prepare_generation_parameters( + generation_parameters, + sampling_parameters, + structure_logits_processor, + **llama_cpp_params, + ) + llama_cpp_params["stream"] = True + generator = self.model(prompts, **llama_cpp_params) + + def token_generator() -> Iterator[str]: + while True: + try: + result = next(generator) + yield result["choices"][0]["text"] + except StopIteration: + self.model.reset() + return + + return token_generator() + + def load_lora(self, adapter_path: str): + if self.model._model.apply_lora_from_file( + adapter_path, + 1.0, + ): + raise RuntimeError(f"Failed to apply LoRA from lora path: {adapter_path}") + + +def llamacpp( + repo_id: str, filename: Optional[str] = None, **llamacpp_model_params +) -> LlamaCpp: + """Load a model from the `llama-cpp-python` library. + + We use the `Llama.from_pretrained` classmethod that downloads models + directly from the HuggingFace hub, instead of asking users to specify + a path to the downloaded model. One can still load a local model + by initializing `llama_cpp.Llama` directly. + + Arguments + --------- + repo_id + The name of the model repository. + filename: + A filename of glob pattern to match the model file in the repo. + llama_cpp_model_params + Llama-specific model parameters. 
See the `llama-cpp-python` documentation + for the full list: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__ + + """ + from llama_cpp import Llama + + # Default to using the model's full context length + if "n_ctx" not in llamacpp_model_params: + llamacpp_model_params["n_ctx"] = 0 + + if "verbose" not in llamacpp_model_params: + llamacpp_model_params["verbose"] = False + + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + if "tokenizer" not in llamacpp_model_params: + warnings.warn( + "The pre-tokenizer in `llama.cpp` handles unicode improperly " + + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" + + "Outlines may raise a `RuntimeError` when building the regex index.\n" + + "To circumvent this error when using `models.llamacpp()` you may pass the argument" + + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" + ) + + model = Llama.from_pretrained(repo_id, filename, **llamacpp_model_params) + + return LlamaCpp(model) diff --git a/build/lib/outlines/models/mamba.py b/build/lib/outlines/models/mamba.py new file mode 100644 index 000000000..d3dabf669 --- /dev/null +++ b/build/lib/outlines/models/mamba.py @@ -0,0 +1,61 @@ +from typing import TYPE_CHECKING, Optional + +from .transformers import TransformerTokenizer + +if TYPE_CHECKING: + import torch + from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel + from transformers import PreTrainedTokenizer + + +TOKENIZER_MODEL = "EleutherAI/gpt-neox-20b" + + +class Mamba: + """Represent a `mamba` model.""" + + def __init__( + self, model: "MambaLMHeadModel", tokenizer: "PreTrainedTokenizer", device + ): + self.device = device + self.model = model + self.tokenizer = TransformerTokenizer(tokenizer) + + def forward(self, input_ids: "torch.LongTensor", *_): + """Compute a forward pass through the mamba model.""" + + output = self.model(input_ids) + next_token_logits = output.logits[..., -1, :] + return next_token_logits, None + + def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor": + return self.forward(input_ids) + + +def mamba( + model_name: str, + device: Optional[str] = None, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, +): + try: + import torch + from mamba_ssm import MambaLMHeadModel + from transformers import AutoTokenizer + except ImportError: + raise ImportError( + "The `mamba_ssm`, `torch` and `transformer` libraries needs to be installed in order to use Mamba people." 
+ ) + + if not torch.cuda.is_available(): + raise NotImplementedError("Mamba models can only run on GPU.") + else: + if device is None: + device = "cuda" + + model = MambaLMHeadModel.from_pretrained(model_name, device=device) + + tokenizer_kwargs.setdefault("padding_side", "left") + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL, **tokenizer_kwargs) + + return Mamba(model, tokenizer, device) diff --git a/build/lib/outlines/models/openai.py b/build/lib/outlines/models/openai.py new file mode 100644 index 000000000..cedb37420 --- /dev/null +++ b/build/lib/outlines/models/openai.py @@ -0,0 +1,467 @@ +"""Integration with OpenAI's API.""" +import functools +from dataclasses import asdict, dataclass, field, replace +from itertools import zip_longest +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import numpy as np + +from outlines.base import vectorize +from outlines.caching import cache + +__all__ = ["OpenAI", "openai", "azure_openai"] + + +@dataclass(frozen=True) +class OpenAIConfig: + """Represents the parameters of the OpenAI API. + + The information was last fetched on 2023/11/20. We document below the + properties that are specific to the OpenAI API. Not all these properties are + supported by Outlines. + + Properties + ---------- + model + The name of the model. Available models can be found on OpenAI's website. + frequence_penalty + Number between 2.0 and -2.0. Positive values penalize new tokens based on + their existing frequency in the text, + logit_bias + Modifies the likelihood of specified tokens to appear in the completion. + Number between -100 (forbid) and +100 (only allows). + n + The number of completions to return for each prompt. + presence_penalty + Similar to frequency penalty. + response_format + Specifies the format the model must output. `{"type": "json_object"}` + enables JSON mode. + seed + Two completions with the same `seed` value should return the same + completion. This is however not guaranteed. + stop + Up to 4 words where the API will stop the completion. + temperature + Number between 0 and 2. Higher values make the output more random, while + lower values make it more deterministic. + top_p + Number between 0 and 1. Parameter for nucleus sampling. + user + A unique identifier for the end-user. + + """ + + model: str = "" + frequency_penalty: float = 0 + logit_bias: Dict[int, int] = field(default_factory=dict) + max_tokens: Optional[int] = None + n: int = 1 + presence_penalty: float = 0 + response_format: Optional[Dict[str, str]] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + temperature: float = 1.0 + top_p: int = 1 + user: str = field(default_factory=str) + + +class OpenAI: + """An object that represents the OpenAI API.""" + + def __init__( + self, + client, + config, + tokenizer=None, + system_prompt: Optional[str] = None, + ): + """Create an `OpenAI` instance. + + This class supports the standard OpenAI API, the Azure OpeanAI API as + well as compatible APIs that rely on the OpenAI client. + + Parameters + ---------- + client + An instance of the API's async client. + config + An instance of `OpenAIConfig`. Can be useful to specify some + parameters that cannot be set by calling this class' methods. + tokenizer + The tokenizer associated with the model the client connects to. 
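A short sketch of how the `OpenAIConfig` dataclass above is typically paired with the `openai` entry point defined later in this file; the model name is a placeholder:

from outlines import models
from outlines.models.openai import OpenAIConfig

config = OpenAIConfig(
    temperature=0.5,
    seed=42,
    stop=["\n\n"],
)
model = models.openai("gpt-3.5-turbo", config=config)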
+ + """ + + self.client = client + self.tokenizer = tokenizer + self.config = config + + # We count the total number of prompt and generated tokens as returned + # by the OpenAI API, summed over all the requests performed with this + # model instance. + self.prompt_tokens = 0 + self.completion_tokens = 0 + + def __call__( + self, + prompt: Union[str, List[str]], + max_tokens: Optional[int] = None, + stop_at: Optional[Union[List[str], str]] = None, + *, + system_prompt: Optional[str] = None, + temperature: Optional[float] = None, + samples: Optional[int] = None, + ) -> np.ndarray: + """Call the OpenAI API to generate text. + + Parameters + ---------- + prompt + A string or list of strings that will be used to prompt the model + max_tokens + The maximum number of tokens to generate + stop_at + A string or array of strings which, such that the generation stops + when they are generated. + system_prompt + The content of the system message that precedes the user's prompt. + temperature + The value of the temperature used to sample tokens + samples + The number of completions to generate for each prompt + stop_at + Up to 4 words where the API will stop the completion. + + """ + if max_tokens is None: + max_tokens = self.config.max_tokens + if stop_at is None: + stop_at = self.config.stop + if temperature is None: + temperature = self.config.temperature + if samples is None: + samples = self.config.n + + config = replace(self.config, max_tokens=max_tokens, temperature=temperature, n=samples, stop=stop_at) # type: ignore + + response, prompt_tokens, completion_tokens = generate_chat( + prompt, system_prompt, self.client, config + ) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + return response + + def stream(self, *args, **kwargs): + raise NotImplementedError( + "Streaming is currently not supported for the OpenAI API" + ) + + def generate_choice( + self, + prompt: str, + choices: List[str], + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + ) -> str: + """Call the OpenAI API to generate one of several choices. + + Parameters + ---------- + prompt + A string or list of strings that will be used to prompt the model + choices + The list of strings between which we ask the model to choose + max_tokens + The maximum number of tokens to generate + system_prompt + The content of the system message that precedes the user's prompt. 
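Per-call arguments override the values stored in the config, as the body of `__call__` above shows; a hedged example, again with a placeholder model name:

import outlines

model = outlines.models.openai("gpt-3.5-turbo")
completions = model(
    "Write a haiku about finite-state machines.",
    max_tokens=64,
    temperature=0.3,
    samples=2,                                  # maps onto the `n` field of the config
    system_prompt="You are a concise assistant.",
)
# Cumulative token usage across all calls made with this instance.
print(model.prompt_tokens, model.completion_tokens)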
+ + """ + if self.tokenizer is None: + raise ValueError( + "You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`" + ) + + config = replace(self.config, max_tokens=max_tokens) + + greedy = False + decoded: List[str] = [] + encoded_choices_left: List[List[int]] = [ + self.tokenizer.encode(word) for word in choices + ] + + while len(encoded_choices_left) > 0: + max_tokens_left = max([len(tokens) for tokens in encoded_choices_left]) + transposed_choices_left: List[Set] = [ + {item for item in subset if item is not None} + for subset in zip_longest(*encoded_choices_left) + ] + + if not greedy: + mask = build_optimistic_mask(transposed_choices_left) + else: + mask = {} + for token in transposed_choices_left[0]: # build greedy mask + mask[token] = 100 + + if len(mask) == 0: + break + + config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) + + response, prompt_tokens, completion_tokens = generate_chat( + prompt, system_prompt, self.client, config + ) + self.prompt_tokens += prompt_tokens + self.completion_tokens += completion_tokens + + encoded_response = self.tokenizer.encode(response) + + if encoded_response in encoded_choices_left: + decoded.append(response) + break + else: + ( + encoded_response, + encoded_choices_left, + ) = find_response_choices_intersection( + encoded_response, encoded_choices_left + ) + + if len(encoded_response) == 0: + greedy = True # next iteration will be "greedy" + continue + else: + decoded.append("".join(self.tokenizer.decode(encoded_response))) + + if len(encoded_choices_left) == 1: # only one choice left + choice_left = self.tokenizer.decode(encoded_choices_left[0]) + decoded.append(choice_left) + break + + greedy = False # after each success, stay with (or switch to) "optimistic" approach + + prompt = prompt + "".join(decoded) + + choice = "".join(decoded) + + return choice + + def generate_json(self): + """Call the OpenAI API to generate a JSON object.""" + raise NotImplementedError + + def __str__(self): + return self.__class__.__name__ + " API" + + def __repr__(self): + return str(self.config) + + +@functools.partial(vectorize, signature="(),(),(),()->(s),(),()") +async def generate_chat( + prompt: str, + system_prompt: Union[str, None], + client, + config: OpenAIConfig, +) -> Tuple[np.ndarray, int, int]: + """Call OpenAI's Chat Completion API. + + Parameters + ---------- + prompt + The prompt we use to start the generation. Passed to the model + with the "user" role. + system_prompt + The system prompt, passed to the model with the "system" role + before the prompt. + client + The API client + config + An `OpenAIConfig` instance. + + Returns + ------- + A tuple that contains the model's response(s) and usage statistics. 
+ + """ + + @error_handler + @cache() + async def call_api(prompt, system_prompt, config): + responses = await client.chat.completions.create( + messages=system_message + user_message, + **asdict(config), # type: ignore + ) + return responses.model_dump() + + system_message = ( + [{"role": "system", "content": system_prompt}] if system_prompt else [] + ) + user_message = [{"role": "user", "content": prompt}] + + responses = await call_api(prompt, system_prompt, config) + + results = np.array( + [responses["choices"][i]["message"]["content"] for i in range(config.n)] + ) + usage = responses["usage"] + + return results, usage["prompt_tokens"], usage["completion_tokens"] + + +def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]: + """Find the longest intersection between the response and the choice.""" + for i, (token_r, token_c) in enumerate(zip_longest(response, choice)): + if token_r != token_c: + return response[:i] + + return response + + +def find_response_choices_intersection( + response: List[int], choices: List[List[int]] +) -> Tuple[List[int], List[List[int]]]: + """Find the longest intersection between the response and the different + choices. + + Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices + `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2, 3]` as the + intersection, and `[[]]` as the list of choices left. + + Parameters + ---------- + response + The model's response + choices + The remaining possible choices + + Returns + ------- + A tuple that contains the longest intersection between the response and the + different choices, and the choices which start with this intersection, with the + intersection removed. + + """ + max_len_prefix = 0 + choices_left = [] + longest_prefix = [] + for i, choice in enumerate(choices): + # Find the longest intersection between the response and the choice. + prefix = find_longest_intersection(response, choice) + + if len(prefix) > max_len_prefix: + max_len_prefix = len(prefix) + choices_left = [choice[len(prefix) :]] + longest_prefix = prefix + + elif len(prefix) == max_len_prefix: + choices_left.append(choice[len(prefix) :]) + + return longest_prefix, choices_left + + +def build_optimistic_mask( + transposed: List[Set[int]], max_mask_size: int = 300 +) -> Dict[int, int]: + """We build the largest mask possible. + + Tokens are added from left to right, so if the encoded choices are e.g. + `[[1,2], [3,4]]`, `1` and `3` will be added before `2` and `4`. + + Parameters + ---------- + transposed + A list of lists that contain the nth token of each choice. 
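A small self-contained check of the intersection helper documented above, assuming it is imported from this module; the token ids are made up for illustration:

response = [1, 2, 3, 4, 5]
choices = [[1, 2], [1, 2, 3], [6, 7, 8]]

prefix, choices_left = find_response_choices_intersection(response, choices)
assert prefix == [1, 2, 3]     # longest shared prefix with any choice
assert choices_left == [[]]    # the choice [1, 2, 3] is fully consumed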
+ + """ + mask: Dict[int, int] = {} + for tokens in transposed: + for token in tokens: + if len(mask) == max_mask_size: + return mask + mask[token] = 100 + + return mask + + +def error_handler(api_call_fn: Callable) -> Callable: + """Handle OpenAI API errors and missing API key.""" + + def call(*args, **kwargs): + import openai + + try: + return api_call_fn(*args, **kwargs) + except ( + openai.APITimeoutError, + openai.InternalServerError, + openai.RateLimitError, + ) as e: + raise OSError(f"Could not connect to the OpenAI API: {e}") + except ( + openai.AuthenticationError, + openai.BadRequestError, + openai.ConflictError, + openai.PermissionDeniedError, + openai.NotFoundError, + openai.UnprocessableEntityError, + ) as e: + raise e + + return call + + +def openai( + model_name: str, + api_key: Optional[str] = None, + config: Optional[OpenAIConfig] = None, +): + try: + import tiktoken + from openai import AsyncOpenAI + except ImportError: + raise ImportError( + "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration." + ) + + if config is not None: + config = replace(config, model=model_name) # type: ignore + else: + config = OpenAIConfig(model=model_name) + + client = AsyncOpenAI(api_key=api_key) + tokenizer = tiktoken.encoding_for_model(model_name) + + return OpenAI(client, config, tokenizer) + + +def azure_openai( + deployment_name: str, + model_name: Optional[str] = None, + azure_endpoint: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + config: Optional[OpenAIConfig] = None, +): + try: + import tiktoken + from openai import AsyncAzureOpenAI + except ImportError: + raise ImportError( + "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration." + ) + + if config is not None: + config = replace(config, model=deployment_name) # type: ignore + if config is None: + config = OpenAIConfig(model=deployment_name) + + client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key + ) + tokenizer = tiktoken.encoding_for_model(model_name or deployment_name) + + return OpenAI(client, config, tokenizer) diff --git a/build/lib/outlines/models/tokenizer.py b/build/lib/outlines/models/tokenizer.py new file mode 100644 index 000000000..949414a44 --- /dev/null +++ b/build/lib/outlines/models/tokenizer.py @@ -0,0 +1,31 @@ +from typing import Dict, Hashable, List, Protocol, Set, Tuple, Union + +import numpy as np +from numpy.typing import NDArray + + +class Tokenizer(Hashable, Protocol): + eos_token: str + eos_token_id: int + pad_token_id: int + vocabulary: Dict[str, int] + special_tokens: Set[int] + + def encode( + self, prompt: Union[str, List[str]] + ) -> Tuple[NDArray[np.int64], NDArray[np.int64]]: + """Translate the input prompts into arrays of token ids and attention mask.""" + ... + + def decode(self, token_ids: NDArray[np.int64]) -> List[str]: + """Translate an array of token ids to a string or list of strings.""" + ... + + def convert_token_to_string(self, token: str) -> str: + """Convert a token to its equivalent string. + + This is for instance useful for BPE tokenizers where whitespaces are + represented by the special characted `Ġ`. This prevents matching a raw + token that includes `Ġ` with a string. + """ + ... 
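A toy class satisfying the `Tokenizer` protocol above, shown only to make the expected attributes and methods concrete; the vocabulary is a stand-in, not a real tokenizer:

import numpy as np


class ToyTokenizer:
    eos_token = "<eos>"
    eos_token_id = 0
    pad_token_id = 0
    vocabulary = {"<eos>": 0, "a": 1, "b": 2}
    special_tokens = {0}

    def encode(self, prompt):
        ids = np.array([[self.vocabulary.get(ch, 0) for ch in prompt]])
        return ids, np.ones_like(ids)

    def decode(self, token_ids):
        inverse = {i: tok for tok, i in self.vocabulary.items()}
        return ["".join(inverse[int(i)] for i in row) for row in token_ids]

    def convert_token_to_string(self, token):
        return token

    def __hash__(self):
        return hash("toy-tokenizer")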
diff --git a/build/lib/outlines/models/transformers.py b/build/lib/outlines/models/transformers.py new file mode 100644 index 000000000..fae9b8e74 --- /dev/null +++ b/build/lib/outlines/models/transformers.py @@ -0,0 +1,236 @@ +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from datasets.fingerprint import Hasher + +from outlines.models.tokenizer import Tokenizer + +if TYPE_CHECKING: + import torch + from transformers import PreTrainedModel, PreTrainedTokenizer + +__all__ = ["transformers"] + + +KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...] + + +def get_llama_tokenizer_types(): + """Get all the Llama tokenizer types/classes that need work-arounds. + + When they can't be imported, a dummy class is created. + + """ + try: + from transformers.models.llama import LlamaTokenizer + except ImportError: + + class LlamaTokenizer: # type: ignore + pass + + try: + from transformers.models.llama import LlamaTokenizerFast + except ImportError: + + class LlamaTokenizerFast: # type: ignore + pass + + try: + from transformers.models.code_llama import CodeLlamaTokenizer + except ImportError: + + class CodeLlamaTokenizer: # type: ignore + pass + + try: + from transformers.models.code_llama import CodeLlamaTokenizerFast + except ImportError: + + class CodeLlamaTokenizerFast: # type: ignore + pass + + return ( + LlamaTokenizer, + LlamaTokenizerFast, + CodeLlamaTokenizer, + CodeLlamaTokenizerFast, + ) + + +class TransformerTokenizer(Tokenizer): + """Represents a tokenizer for models in the `transformers` library.""" + + def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs): + self.tokenizer = tokenizer + self.eos_token_id = self.tokenizer.eos_token_id + self.eos_token = self.tokenizer.eos_token + + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.pad_token_id = self.eos_token_id + else: + self.pad_token_id = self.tokenizer.pad_token_id + self.pad_token = self.tokenizer.pad_token + + self.special_tokens = set(self.tokenizer.all_special_tokens) + + self.vocabulary = self.tokenizer.get_vocab() + self.is_llama = isinstance(self.tokenizer, get_llama_tokenizer_types()) + + def encode( + self, prompt: Union[str, List[str]], **kwargs + ) -> Tuple["torch.LongTensor", "torch.LongTensor"]: + kwargs["padding"] = True + kwargs["return_tensors"] = "pt" + output = self.tokenizer(prompt, **kwargs) + return output["input_ids"], output["attention_mask"] + + def decode(self, token_ids: "torch.LongTensor") -> List[str]: + text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True) + return text + + def convert_token_to_string(self, token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = self.tokenizer.convert_tokens_to_string([token]) + + if self.is_llama: + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def __eq__(self, other): + if isinstance(other, type(self)): + if hasattr(self, "model_name") and hasattr(self, "kwargs"): + return ( + other.model_name == self.model_name and other.kwargs == self.kwargs + ) + else: + return other.tokenizer == self.tokenizer + return NotImplemented + + def __hash__(self): + return hash(Hasher.hash(self.tokenizer)) + + def __getstate__(self): + state = {"tokenizer": self.tokenizer} + return state + + def __setstate__(self, state): + self.__init__(state["tokenizer"]) + + +class Transformers: + """Represents a `transformers` 
model.""" + + def __init__( + self, + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + ): + self.device = model.device + self.model = model + self.tokenizer = TransformerTokenizer(tokenizer) + + def forward( + self, + input_ids: "torch.LongTensor", + attention_mask: "torch.LongTensor", + past_key_values: Optional[Tuple] = None, + ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]: + """Compute a forward pass through the transformer model. + + Parameters + ---------- + input_ids + The input token ids. Must be one or two dimensional. + attention_mask + The attention mask. Must be one or two dimensional. + past_key_values + A tuple of tuples containing the cached key and value tensors for each + attention head. + + Returns + ------- + The computed logits and the new cached key and value tensors. + + """ + try: + import torch + except ImportError: + ImportError( + "The `torch` library needs to be installed to use `transformers` models." + ) + assert 0 < input_ids.ndim < 3 + + if past_key_values: + input_ids = input_ids[..., -1].unsqueeze(-1) + + with torch.inference_mode(): + output = self.model( + input_ids, + attention_mask=attention_mask, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + past_key_values=past_key_values, + ) + + return output.logits, output.past_key_values + + def __call__( + self, + input_ids: "torch.LongTensor", + attention_mask: "torch.LongTensor", + past_key_values: Optional[Tuple] = None, + ) -> "torch.FloatTensor": + logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values) + next_token_logits = logits[..., -1, :] + + return next_token_logits, kv_cache + + +def transformers( + model_name: str, + device: Optional[str] = None, + model_kwargs: dict = {}, + tokenizer_kwargs: dict = {}, +): + """Instantiate a model from the `transformers` library and its tokenizer. + + Parameters + ---------- + model_name + The name of the model as listed on Hugging Face's model page. + device + The device(s) on which the model should be loaded. This overrides + the `device_map` entry in `model_kwargs` when provided. + model_kwargs + A dictionary that contains the keyword arguments to pass to the + `from_pretrained` method when loading the model. + tokenizer_kwargs + A dictionary that contains the keyword arguments to pass to the + `from_pretrained` method when loading the tokenizer. + + Returns + ------- + A `TransformersModel` model instance. + + """ + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + raise ImportError( + "The `transformers` library needs to be installed in order to use `transformers` models." + ) + + if device is not None: + model_kwargs["device_map"] = device + + model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + + tokenizer_kwargs.setdefault("padding_side", "left") + tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs) + + return Transformers(model, tokenizer) diff --git a/build/lib/outlines/models/vllm.py b/build/lib/outlines/models/vllm.py new file mode 100644 index 000000000..378a35d91 --- /dev/null +++ b/build/lib/outlines/models/vllm.py @@ -0,0 +1,159 @@ +import dataclasses +from typing import TYPE_CHECKING, List, Optional, Union + +from outlines.generate.api import GenerationParameters, SamplingParameters + +if TYPE_CHECKING: + from vllm import LLM + from vllm.sampling_params import SamplingParams + + +class VLLM: + """Represents a vLLM model. 
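A hedged usage sketch of the `transformers` loader above combined with regex-guided generation; `gpt2` is used only as a small placeholder checkpoint:

import outlines

model = outlines.models.transformers("gpt2", device="cpu")
generator = outlines.generate.regex(model, r"[0-9]{3}")
print(generator("A three digit number: "))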
+ + We wrap models from model providing libraries in order to give all of + them the same interface in Outlines and allow users to easily switch + between providers. This class wraps the `vllm.LLM` class from the + `vllm` library. + + """ + + def __init__(self, model: "LLM"): + self.model = model + self.lora_request = None + + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + logits_processor, + sampling_parameters: SamplingParameters, + *, + sampling_params: Optional["SamplingParams"] = None, + ): + """Generate text using vLLM. + + Arguments + --------- + prompts + A prompt or list of prompts. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + samplng_params + An instance of `vllm.sampling_params.SamplingParams`. The values + passed via this dataclass supersede the values of the parameters + in `generation_parameters` and `sampling_parameters`. See the + vLLM documentation for more details: https://docs.vllm.ai/en/latest/dev/sampling_params.html. + + Returns + ------- + The generated text, of shape `(n_batch, n_samples)`. If there are only + one batch and several samples, the list is of shape `(n_samples)`. If + this is a batch with several sequences but only one sample the list is + of shape `(n_batch)`. If there is only one sequence and one sample, a + string is returned. + + """ + from vllm.sampling_params import SamplingParams + + if sampling_params is None: + sampling_params = SamplingParams() + + max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + + # We only update the values in `sampling_params` if they + # are specified by the user when calling the generator. + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if stop_at is not None: + if isinstance(stop_at, str): + stop_at = [stop_at] + sampling_params.stop = stop_at + if seed is not None: + sampling_params.seed = seed + + sampling_params.logits_processors = ( + [logits_processor] if logits_processor is not None else [] + ) + + sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( + sampling_parameters + ) + + # We only update the values in `sampling_params` that + # were not specified by the user. 
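A sketch of the precedence implemented above: fields set explicitly on a `SamplingParams` object win over the Outlines sampler, assuming the generator forwards `sampling_params` through to this method; the model name is a placeholder:

from vllm.sampling_params import SamplingParams

import outlines

model = outlines.models.vllm("mistralai/Mistral-7B-Instruct-v0.2")
generator = outlines.generate.text(
    model, sampler=outlines.samplers.multinomial(temperature=1.0)
)

# temperature=0.2 survives because it differs from the vLLM default (1.0)
# that the code above checks against before applying the sampler's value.
params = SamplingParams(temperature=0.2, max_tokens=64)
answer = generator("Give me a short greeting.", sampling_params=params)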
+ if sampling_params.n == 1: + sampling_params.n = num_samples + sampling_params.best_of = num_samples + if top_p is not None and sampling_params.top_p == 1.0: + sampling_params.top_p = top_p + if top_k is not None and sampling_params.top_k == -1: + sampling_params.top_k = top_k + if temperature is not None and sampling_params.temperature == 1.0: + sampling_params.temperature = temperature + if sampler == "beam_search": + sampling_params.use_beam_search = True + + results = self.model.generate( + prompts, sampling_params=sampling_params, lora_request=self.lora_request + ) + results = [[sample.text for sample in batch.outputs] for batch in results] + + batch_size = len(results) + sample_size = len(results[0]) + + if batch_size == 1 and sample_size == 1: + return results[0][0] + elif batch_size == 1: + return results[0] + elif sample_size == 1: + return [batch[0] for batch in results] + + return results + + def stream(self, *args, **kwargs): + """Return a text generator. + + Streaming is not yet available for `vllm.LLM`. + + TODO: Implement the streaming functionality ourselves. + + """ + raise NotImplementedError( + "Streaming is not available for the vLLM integration." + ) + + def load_lora(self, adapter_path: Optional[str]): + from vllm.lora.request import LoRARequest + + if adapter_path is None: + self.lora_request = None + else: + self.lora_request = LoRARequest(adapter_path, 1, adapter_path) + + +def vllm(model_name: str, **vllm_model_params): + """Load a vLLM model. + + Arguments + --------- + model_name + The name of the model to load from the HuggingFace hub. + vllm_model_params + vLLM-specific model parameters. See the vLLM code for the full list: + https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + + """ + from vllm import LLM + + model = LLM(model_name, **vllm_model_params) + + return VLLM(model) diff --git a/build/lib/outlines/prompts.py b/build/lib/outlines/prompts.py new file mode 100644 index 000000000..01e900c96 --- /dev/null +++ b/build/lib/outlines/prompts.py @@ -0,0 +1,338 @@ +import functools +import inspect +import json +import re +import textwrap +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Type, cast + +from jinja2 import Environment, StrictUndefined +from pydantic import BaseModel + + +@dataclass +class Prompt: + """Represents a prompt function. + + We return a `Prompt` class instead of a simple function so the + template defined in prompt functions can be accessed. + + """ + + template: str + signature: inspect.Signature + + def __post_init__(self): + self.parameters: List[str] = list(self.signature.parameters.keys()) + + def __call__(self, *args, **kwargs) -> str: + """Render and return the template. + + Returns + ------- + The rendered template as a Python ``str``. + + """ + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return render(self.template, **bound_arguments.arguments) + + def __str__(self): + return self.template + + +def prompt(fn: Callable) -> Prompt: + """Decorate a function that contains a prompt template. + + This allows to define prompts in the docstring of a function and simplify their + manipulation by providing some degree of encapsulation. It uses the `render` + function internally to render templates. + + >>> import outlines + >>> + >>> @outlines.prompt + >>> def build_prompt(question): + ... "I have a ${question}" + ... 
+ >>> prompt = build_prompt("How are you?") + + This API can also be helpful in an "agent" context where parts of the prompt + are set when the agent is initialized and never modified later. In this situation + we can partially apply the prompt function at initialization. + + >>> import outlines + >>> import functools as ft + ... + >>> @outlines.prompt + ... def solve_task(name: str, objective: str, task: str): + ... '''Your name is {{name}}. + .. Your overall objective is to {{objective}}. + ... Please solve the following task: {{task}} + ... ''' + ... + >>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter") + + Returns + ------- + A `Prompt` callable class which will render the template when called. + + """ + + signature = inspect.signature(fn) + + # The docstring contains the template that will be rendered to be used + # as a prompt to the language model. + docstring = fn.__doc__ + if docstring is None: + raise TypeError("Could not find a template in the function's docstring.") + + template = cast(str, docstring) + + return Prompt(template, signature) + + +def render(template: str, **values: Optional[Dict[str, Any]]) -> str: + r"""Parse a Jinaj2 template and translate it into an Outlines graph. + + This function removes extra whitespaces and linebreaks from templates to + allow users to enter prompts more naturally than if they used Python's + constructs directly. See the examples for a detailed explanation. + + Examples + -------- + + Outlines follow Jinja2's syntax + + >>> import outlines + >>> outline = outlines.render("I like {{food}} and {{sport}}", food="tomatoes", sport="tennis") + I like tomatoes and tennis + + If the first line of the template is empty, `render` removes it + + >>> from outlines import render + >>> + >>> tpl = ''' + ... A new string''' + >>> tpl + ... '\nA new string' + >>> render(tpl) + ... 'a new string' + + Similarly, `render` ignores linebreaks introduced by placing the closing quotes + underneath the text: + + >>> tpl = ''' + ... A new string + ... ''' + >>> tpl + ... '\nA new string\n' + >>> render(tpl) + ... 'A new string' + + If you want to insert a linebreak at the end of the rendered template, you will + need to leave an empty line at the end of the template: + + >>> tpl = ''' + ... A new string + ... + ... ''' + >>> tpl + ... '\nA new string\n\n' + >>> render(tpl) + ... 'A new string\n' + + `render` removes the identation in docstrings. This is particularly important + when using prompt functions + + >>> tpl = ''' + ... a string + ... and another string''' + >>> tpl + ... '\n a string\n and another string' + >>> render(tpl) + ... 'a string\nand another string' + + The indentation of the first line is assumed to be the same as the second line's + + >>> tpl = '''a string + ... and another''' + >>> tpl + ... 'a string\n and another' + >>> render(tpl) + ... 'a string\nand another' + + To get a different indentation for the first and the second line, we can start the + prompt on the string's second line: + + >>> tpl = ''' + ... First line + ... Second line''' + >>> render(tpl) + ... 'First Line\n Second Line' + + Parameters + ---------- + template + A string that contains a template written with the Jinja2 syntax. + **values + Map from the variables in the template to their value. + + Returns + ------- + A string that contains the rendered template. 
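A self-contained example of the decorator documented above, in the spirit of its doctests; the few-shot data is invented for illustration:

import outlines


@outlines.prompt
def few_shot(instructions, examples, question):
    """{{ instructions }}

    {% for example in examples %}
    Q: {{ example.question }}
    A: {{ example.answer }}
    {% endfor %}
    Q: {{ question }}
    A: """


examples = [
    {"question": "What color is the sky?", "answer": "Blue."},
    {"question": "How many legs does a spider have?", "answer": "Eight."},
]
print(few_shot("Answer briefly.", examples, "What is 2 + 2?"))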
+ + """ + # Dedent, and remove extra linebreak + cleaned_template = inspect.cleandoc(template) + + # Add linebreak if there were any extra linebreaks that + # `cleandoc` would have removed + ends_with_linebreak = template.replace(" ", "").endswith("\n\n") + if ends_with_linebreak: + cleaned_template += "\n" + + # Remove extra whitespaces, except those that immediately follow a newline symbol. + # This is necessary to avoid introducing whitespaces after backslash `\` characters + # used to continue to the next line without linebreak. + cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + + env = Environment( + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=StrictUndefined, + ) + env.filters["name"] = get_fn_name + env.filters["description"] = get_fn_description + env.filters["source"] = get_fn_source + env.filters["signature"] = get_fn_signature + env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args + + jinja_template = env.from_string(cleaned_template) + + return jinja_template.render(**values) + + +def get_fn_name(fn: Callable): + """Returns the name of a callable.""" + if not callable(fn): + raise TypeError("The `name` filter only applies to callables.") + + if not hasattr(fn, "__name__"): + name = type(fn).__name__ + else: + name = fn.__name__ + + return name + + +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + arg_str_list = [str(param) for param in signature.parameters.values()] + arg_str = ", ".join(arg_str_list) + return arg_str + + +def get_fn_description(fn: Callable): + """Returns the first line of a callable's docstring.""" + if not callable(fn): + raise TypeError("The `description` filter only applies to callables.") + + docstring = inspect.getdoc(fn) + if docstring is None: + description = "" + else: + description = docstring.split("\n")[0].strip() + + return description + + +def get_fn_source(fn: Callable): + """Return the source code of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"(\bdef\b.*)", re.DOTALL), source) + if re_search is not None: + source = re_search.group(0) + else: + raise TypeError("Could not read the function's source code") + + return source + + +def get_fn_signature(fn: Callable): + """Return the signature of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"\(([^)]+)\)"), source) + if re_search is None: + signature = "" + else: + signature = re_search.group(1) + + return signature + + +@functools.singledispatch +def get_schema(model: Any): + raise NotImplementedError( + f"No schema rendering function defined for type {type(model)}." 
+ ) + + +@get_schema.register(dict) +def get_schema_dict(model: Dict): + """Return a pretty-printed dictionary""" + return json.dumps(model, indent=2) + + +@get_schema.register(type(BaseModel)) +def get_schema_pydantic(model: Type[BaseModel]): + """Return the schema of a Pydantic model.""" + if not type(model) == type(BaseModel): + raise TypeError("The `schema` filter only applies to Pydantic models.") + + if hasattr(model, "model_json_schema"): + def_key = "$defs" + raw_schema = model.model_json_schema() + else: # pragma: no cover + def_key = "definitions" + raw_schema = model.schema() + + definitions = raw_schema.get(def_key, None) + schema = parse_pydantic_schema(raw_schema, definitions) + + return json.dumps(schema, indent=2) + + +def parse_pydantic_schema(raw_schema, definitions): + """Parse the output of `Basemodel.[schema|model_json_schema]()`. + + This recursively follows the references to other schemas in case + of nested models. Other schemas are stored under the "definitions" + key in the schema of the top-level model. + + """ + simple_schema = {} + for name, value in raw_schema["properties"].items(): + if "description" in value: + simple_schema[name] = value["description"] + elif "$ref" in value: + refs = value["$ref"].split("/") + simple_schema[name] = parse_pydantic_schema( + definitions[refs[2]], definitions + ) + else: + simple_schema[name] = f"<{name}>" + + return simple_schema diff --git a/build/lib/outlines/py.typed b/build/lib/outlines/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/build/lib/outlines/samplers.py b/build/lib/outlines/samplers.py new file mode 100644 index 000000000..8b64ed768 --- /dev/null +++ b/build/lib/outlines/samplers.py @@ -0,0 +1,324 @@ +import math +from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple + +if TYPE_CHECKING: + import torch + + +class Sampler(Protocol): + samples: int + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + rng: "torch.Generator", + ) -> "torch.DoubleTensor": + ... + + +class GreedySampler: + """Greedy Sampling algorithm. + + Greedy sampling consists in choosing the token with the largest + likelihood at every step. + + We don't allow more than one sample. We could attribute this a meaning, for + instance the k-th sample represents the k-th most likely token. In which + case it would be equivalent to beam search without the sequence weights. + + Attributes + ---------- + samples + The number of samples taken for each input sequence. + + """ + + def __init__(self): + self.samples = 1 + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + _, + ) -> "torch.DoubleTensor": + """Call the greedy sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. 
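A hedged sketch of how the samplers defined in this module are selected when building generators; the model checkpoint is a placeholder:

import outlines
from outlines.samplers import beam_search, greedy, multinomial

model = outlines.models.transformers("gpt2")

deterministic = outlines.generate.text(model, sampler=greedy())
diverse = outlines.generate.text(model, sampler=multinomial(3, temperature=0.8))
beams = outlines.generate.text(model, sampler=beam_search(beams=4))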
+ + """ + import torch + + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True) + + ancestors = torch.arange( + next_token_logits.shape[0], device=next_token_logits.device + ) + weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze() + + return next_token_ids, ancestors, weights + + +greedy = GreedySampler + + +class MultinomialSampler: + """Multinomial sampling algorithm. + + Multinomial sampling consists in randomly sampling the next token assuming + its distribution is a Categorical distribution parametrized by the + next-token logits. + + + Attributes + ---------- + samples + The number of samples taken for each input sequence. + + """ + + def __init__( + self, + samples: int = 1, + *, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + ): + self.samples = samples + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + + self.logits_processors = [] + if top_k is not None: + self.logits_processors.append(keep_top_k_logits(top_k)) + elif top_p is not None: + self.logits_processors.append(keep_top_p_logits(top_p)) + + if temperature is not None: + self.logits_processors.append(rescale_logits(temperature)) + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + rng: "torch.Generator", + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: + """Call the multinomial sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. + + """ + import torch + + altered_next_token_logits = next_token_logits + for logit_processor in self.logits_processors: + altered_next_token_logits = logit_processor(next_token_logits) + + probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1) + next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng) + + logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1) + ancestors = torch.arange( + altered_next_token_logits.shape[0], device=next_token_logits.device + ) + weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze() + + return next_token_ids, ancestors, weights + + +multinomial = MultinomialSampler + + +def keep_top_k_logits(k: int) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that masks logits values smaller than the top `k` ones. + + Parameters + ---------- + k + The ranking below which logit values are replaced by `-math.inf`. 
+ + """ + import torch + + if not isinstance(k, int) or k < 1: + raise ValueError(f"`k` must be a strictly positive integers, got {k} instead.") + + def logits_processor(logits: torch.Tensor) -> torch.Tensor: + num_to_keep = min(k, logits.size(-1)) + mask_idx = logits < torch.topk(logits, num_to_keep)[0][..., -1, None] + return logits.masked_fill(mask_idx, -math.inf) + + return logits_processor + + +def keep_top_p_logits(p: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that masks the lowest probability tokens whose + cumulative probability is below a certain threshold. + + Parameters + ---------- + p + The value of the threshold. We keep the highest probability tokens whose + cumulative distribution is greater than or equal to `p` and mask the + others. Its value must be between 0 (excluded) and 1 (included). + + """ + import torch + + if p <= 0.0 or p > 1.0: + raise ValueError( + f"`p` must be a floating point number between 0 (excluded) and 1 (included), got {p} instead." + ) + + def logits_processor(logits: torch.Tensor) -> torch.Tensor: + sorted_logits, sorted_idx = torch.sort(logits, descending=False) + cumulative_probabilties = torch.nn.functional.softmax( + sorted_logits, dim=-1 + ).cumsum(dim=-1) + + sorted_masked_idx = cumulative_probabilties <= (1 - p) + mask_idx = torch.scatter(sorted_masked_idx, 1, sorted_idx, sorted_masked_idx) + return logits.masked_fill(mask_idx, -math.inf) + + return logits_processor + + +def rescale_logits(temperature: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: + """Build a function that rescales the token probabilities exponentially. + + Parameters + ---------- + temperature + The value by which we rescale the logits. + + """ + + if not isinstance(temperature, float) or temperature < 0.0: + raise ValueError( + f"`temperature` must be a strictly positive floating point number, got {temperature} instead." + ) + elif temperature == 0.0: + raise ValueError( + "Please use the greedy sampler instead of setting the temperature to 0." + ) + + def logits_processor(logits: "torch.Tensor") -> "torch.Tensor": + return logits / temperature + + return logits_processor + + +class BeamSearchSampler: + """Beam Search sampling algorithm. + + Attributes + ---------- + samples + The number of samples taken for each input sequence. + + """ + + def __init__(self, beams: int = 1): + self.samples = beams + + def __call__( + self, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + _, + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: + """Call the beam search sampler. + + Parameters + ---------- + next_token_logits + A tensor of shape ``(n_seqs, vocab_size,)`` that represents the + probability distribution of the next token over the vocabulary. + sequence_weights + A tensor of shape ``(n_seqs,)`` that represents the cumulative + weight of each sequence. + rng + A random number generator. + + Returns + ------- + A tuple with an array that contains the ids of the sampled tokens of + shape ``(n_seqs, 1)``, an array that contains the ancestors of each + sampled id of shape ``(n_seqs,)`` and an array that contains the updated + cumulative weights of each sequence of shape ``(n_seqs,)``. + + """ + import torch + + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + weights = logprobs + sequence_weights.unsqueeze(1).expand_as(next_token_logits) + + # Flatten scores to (n_batch, n_samples * vocab_size) + # and find the top-k weights for each batch. 
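+        # `next_token_logits` arrives with shape (batch_size * samples, vocab_size):
+        # every prompt in the batch contributes `self.samples` candidate beams,
+        # so the flattened view below groups all of a prompt's beams together
+        # and a single top-k picks the best (beam, token) pairs per prompt.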
+ batch_size = next_token_logits.shape[0] // self.samples + vocab_size = next_token_logits.shape[-1] + weights = weights.view(batch_size, self.samples * vocab_size) + + # If the weights are all equal to 0 we are at the beginning of the search + # and thus only need to sample from one set of token logits for each + # batch. + if torch.all(sequence_weights == 0): + weights = weights[:, :vocab_size] + + weights, indices = torch.topk( + weights, self.samples, dim=1, largest=True, sorted=True + ) + + ancestors = torch.div(indices, vocab_size, rounding_mode="floor") + next_token_ids = indices % vocab_size + + # Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1) + first_batch_idx = torch.arange( + 0, batch_size * self.samples, self.samples, device=next_token_logits.device + ).unsqueeze(1) + ancestors = ancestors + first_batch_idx + + ancestors = ancestors.view(self.samples * batch_size) + weights = weights.view(self.samples * batch_size) + next_token_ids = next_token_ids.view(self.samples * batch_size, 1) + + return next_token_ids, ancestors, weights + + +beam_search = BeamSearchSampler diff --git a/build/lib/outlines/serve/__init__.py b/build/lib/outlines/serve/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/build/lib/outlines/serve/serve.py b/build/lib/outlines/serve/serve.py new file mode 100644 index 000000000..fb8c80139 --- /dev/null +++ b/build/lib/outlines/serve/serve.py @@ -0,0 +1,136 @@ +# _______________________________ +# / Don't want to self-host? \ +# \ Try .json at http://dottxt.co / +# ------------------------------- +# \ ^__^ +# \ (oo)\_______ +# (__)\ )\/\ +# ||----w | +# || || +# +# +# Copyright 2024- the Outlines developers +# Copyright 2023 the vLLM developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +from typing import AsyncGenerator + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - schema: the JSON schema to use for the generation (if regex is not provided). + - regex: the regex to use for the generation (if schema is not provided). + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). 
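+
+    Example call (illustrative values only; assumes the server listens on
+    localhost:8000 and that the ``requests`` package is available)::
+
+        import requests
+
+        response = requests.post(
+            "http://localhost:8000/generate",
+            json={
+                "prompt": "Question: What is a language model? Answer:",
+                "regex": "[A-Za-z ,.]{1,200}",
+                "max_tokens": 60,
+            },
+        )
+        print(response.json()["text"])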
+ """ + assert engine is not None + + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + + json_schema = request_dict.pop("schema", None) + regex_string = request_dict.pop("regex", None) + if json_schema is not None: + logits_processors = [JSONLogitsProcessor(json_schema, engine.engine)] + elif regex_string is not None: + logits_processors = [RegexLogitsProcessor(regex_string, engine.engine)] + else: + logits_processors = [] + + sampling_params = SamplingParams( + **request_dict, logits_processors=logits_processors # type: ignore + ) + request_id = random_uuid() + + results_generator = engine.generate(prompt, sampling_params, request_id) # type: ignore + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [prompt + output.text for output in request_output.outputs] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) # type: ignore + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + # Adds the `engine_use_ray`, `disable_log_requests` and `max_log_len` + # arguments + engine_args: AsyncEngineArgs = AsyncEngineArgs.from_cli_args(args) # type: ignore + + # Sets default for the model (`facebook/opt-125m`) + engine = AsyncLLMEngine.from_engine_args(engine_args) + + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ) diff --git a/build/lib/outlines/serve/vllm.py b/build/lib/outlines/serve/vllm.py new file mode 100644 index 000000000..ddc50b47d --- /dev/null +++ b/build/lib/outlines/serve/vllm.py @@ -0,0 +1,4 @@ +from outlines.integrations.vllm import ( # noqa[F401] + JSONLogitsProcessor, + RegexLogitsProcessor, +) diff --git a/build/lib/outlines/types/__init__.py b/build/lib/outlines/types/__init__.py new file mode 100644 index 000000000..f4d2b8cd3 --- /dev/null +++ b/build/lib/outlines/types/__init__.py @@ -0,0 +1,4 @@ +from . import airports, countries +from .email import Email +from .isbn import ISBN +from .locales import locale diff --git a/build/lib/outlines/types/airports.py b/build/lib/outlines/types/airports.py new file mode 100644 index 000000000..560da6bd9 --- /dev/null +++ b/build/lib/outlines/types/airports.py @@ -0,0 +1,16 @@ +"""Generate valid airport codes.""" +from enum import Enum + +try: + from pyairports.airports import AIRPORT_LIST +except ImportError: + raise ImportError( + 'The `airports` module requires "pyairports" to be installed. 
You can install it with "pip install pyairports"' + ) + + +AIRPORT_IATA_LIST = list( + {(airport[3], airport[3]) for airport in AIRPORT_LIST if airport[3] != ""} +) + +IATA = Enum("Airport", AIRPORT_IATA_LIST) # type:ignore diff --git a/build/lib/outlines/types/countries.py b/build/lib/outlines/types/countries.py new file mode 100644 index 000000000..888443dc6 --- /dev/null +++ b/build/lib/outlines/types/countries.py @@ -0,0 +1,24 @@ +"""Generate valid country codes and names.""" +from enum import Enum + +try: + import pycountry +except ImportError: + raise ImportError( + 'The `countries` module requires "pycountry" to be installed. You can install it with "pip install pycountry"' + ) + +ALPHA_2_CODE = [(country.alpha_2, country.alpha_2) for country in pycountry.countries] +Alpha2 = Enum("Alpha_2", ALPHA_2_CODE) # type:ignore + +ALPHA_3_CODE = [(country.alpha_3, country.alpha_3) for country in pycountry.countries] +Alpha3 = Enum("Alpha_2", ALPHA_3_CODE) # type:ignore + +NUMERIC_CODE = [(country.numeric, country.numeric) for country in pycountry.countries] +Numeric = Enum("Numeric_code", NUMERIC_CODE) # type:ignore + +NAME = [(country.name, country.name) for country in pycountry.countries] +Name = Enum("Name", NAME) # type:ignore + +FLAG = [(country.flag, country.flag) for country in pycountry.countries] +Flag = Enum("Flag", FLAG) # type:ignore diff --git a/build/lib/outlines/types/email.py b/build/lib/outlines/types/email.py new file mode 100644 index 000000000..45f8c4b2c --- /dev/null +++ b/build/lib/outlines/types/email.py @@ -0,0 +1,11 @@ +"""Email Address types.""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# Taken from StackOverflow +# https://stackoverflow.com/a/201378/14773537 +EMAIL_REGEX = r"""(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*|"(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*")@(?:(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?|\[(?:(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9]))\.){3}(?:(2(5[0-5]|[0-4][0-9])|1[0-9][0-9]|[1-9]?[0-9])|[a-z0-9-]*[a-z0-9]:(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+)\])""" +Email = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": EMAIL_REGEX}), +] diff --git a/build/lib/outlines/types/isbn.py b/build/lib/outlines/types/isbn.py new file mode 100644 index 000000000..5aebb067e --- /dev/null +++ b/build/lib/outlines/types/isbn.py @@ -0,0 +1,12 @@ +"""ISBN type""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# Matches any ISBN number. Note that this is not completely correct as not all +# 10 or 13 digits numbers are valid ISBNs. See https://en.wikipedia.org/wiki/ISBN +# Taken from O'Reilly's Regular Expression Cookbook: +# https://www.oreilly.com/library/view/regular-expressions-cookbook/9781449327453/ch04s13.html +# TODO: Can this be represented by a grammar or do we need semantic checks? +ISBN_REGEX = r"(?:ISBN(?:-1[03])?:? 
)?(?=[0-9X]{10}$|(?=(?:[0-9]+[- ]){3})[- 0-9X]{13}$|97[89][0-9]{10}$|(?=(?:[0-9]+[- ]){4})[- 0-9]{17}$)(?:97[89][- ]?)?[0-9]{1,5}[- ]?[0-9]+[- ]?[0-9]+[- ]?[0-9X]" + +ISBN = Annotated[str, WithJsonSchema({"type": "string", "pattern": ISBN_REGEX})] diff --git a/build/lib/outlines/types/locales.py b/build/lib/outlines/types/locales.py new file mode 100644 index 000000000..c5d251bae --- /dev/null +++ b/build/lib/outlines/types/locales.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from outlines.types.phone_numbers import USPhoneNumber +from outlines.types.zip_codes import USZipCode + + +@dataclass +class US: + ZipCode = USZipCode + PhoneNumber = USPhoneNumber + + +def locale(locale_str: str): + locales = {"us": US} + + if locale_str not in locales: + raise NotImplementedError( + f"The locale {locale_str} is not supported yet. Please don't hesitate to create custom types for you locale and open a Pull Request." + ) + + return locales[locale_str] diff --git a/build/lib/outlines/types/phone_numbers.py b/build/lib/outlines/types/phone_numbers.py new file mode 100644 index 000000000..618687e75 --- /dev/null +++ b/build/lib/outlines/types/phone_numbers.py @@ -0,0 +1,16 @@ +"""Phone number types. + +We currently only support US phone numbers. We can however imagine having custom types +for each country, for instance leveraging the `phonenumbers` library. + +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +US_PHONE_NUMBER = r"(\([0-9]{3}\) |[0-9]{3}-)[0-9]{3}-[0-9]{4}" + + +USPhoneNumber = Annotated[ + str, + WithJsonSchema({"type": "string", "pattern": US_PHONE_NUMBER}), +] diff --git a/build/lib/outlines/types/zip_codes.py b/build/lib/outlines/types/zip_codes.py new file mode 100644 index 000000000..67d994d5c --- /dev/null +++ b/build/lib/outlines/types/zip_codes.py @@ -0,0 +1,13 @@ +"""Zip code types. + +We currently only support US Zip Codes. + +""" +from pydantic import WithJsonSchema +from typing_extensions import Annotated + +# This matches Zip and Zip+4 codes +US_ZIP_CODE = r"\d{5}(?:-\d{4})?" + + +USZipCode = Annotated[str, WithJsonSchema({"type": "string", "pattern": US_ZIP_CODE})] diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 810ef5910..b91f42536 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -10,7 +10,7 @@ from referencing._core import Resolver from referencing.jsonschema import DRAFT202012 -STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\\)' +STRING_INNER = r'([^"\\\x00-\x1f\x7f-\x9f]|\\\S)' STRING = f'"{STRING_INNER}*"' INTEGER = r"(-)?(0|[1-9][0-9]*)" NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" 
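The locale-specific types added above are plain `Annotated` strings that carry a JSON-schema `pattern`, so they can be dropped directly into a Pydantic model. A minimal usage sketch (assuming Pydantic v2; the `Customer` model and its field names are illustrative only):

from pydantic import BaseModel

from outlines.types import locale

US = locale("us")


class Customer(BaseModel):
    name: str
    phone_number: US.PhoneNumber
    zip_code: US.ZipCode


# The generated schema now carries the regex patterns defined above, so
# JSON-guided generation is constrained to valid US phone numbers and ZIP codes.
print(Customer.model_json_schema())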
diff --git a/outlines/generate/json.py b/outlines/generate/json.py index 3837f72b6..985edf3bb 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -47,7 +47,7 @@ def json( schema = pyjson.dumps(schema_object.model_json_schema()) regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: schema_object.parse_raw(x) + generator.format_sequence = lambda x: schema_object.model_validate_json(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) diff --git a/outlines/integrations/vllm.py b/outlines/integrations/vllm.py index f5d90dc6a..805897032 100644 --- a/outlines/integrations/vllm.py +++ b/outlines/integrations/vllm.py @@ -34,6 +34,9 @@ from outlines.fsm.guide import RegexGuide, Write, Generate from outlines.fsm.json_schema import build_regex_from_schema from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str +import logging + +from tqdm import tqdm if TYPE_CHECKING: from vllm import LLM @@ -77,9 +80,9 @@ def __init__(self, regex_string: str, llm: "LLM"): "`tokenizer` attribute or a `get_tokenizer` method." ) tokenizer = adapt_tokenizer(tokenizer=tokenizer) - self.mask_cache = {} self.fsm = RegexGuide(regex_string, tokenizer) self._fsm_state: DefaultDict[int, int] = defaultdict(int) + self.mask_cache = {} def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token. @@ -96,6 +99,18 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: torch.Tensor The biased logits. """ + + if self.mask_cache is None: + self.mask_cache = torch.full((max(self.fsm.states_to_token_maps.keys()) + 2, + scores.shape[-1]), -math.inf, + dtype=scores.dtype) + # fill mask_cache (keys should be from 0 to num_states) + for key in tqdm(self.fsm.states_to_token_maps.keys(), desc="Init mask cache"): + allowed_tokens = self.fsm.get_next_instruction(key).tokens + self.mask_cache[key][allowed_tokens] = 0 + self.mask_cache = self.mask_cache.pin_memory() + #self.mask_cache = self.mask_cache.to(scores.device, non_blocking=True) + seq_id = hash(tuple(input_ids)) # Initialize the FSM state dictionary if the input_ids are empty, as this means @@ -108,20 +123,20 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: ) state = self._fsm_state[seq_id] - if state in self.mask_cache: mask = self.mask_cache[state] else: allowed_tokens = self.fsm.get_next_instruction( state=self._fsm_state[seq_id] ).tokens - mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask = torch.full((scores.shape[-1],), -math.inf).pin_memory() mask[allowed_tokens] = 0 self.mask_cache[state] = mask - biased_scores = scores + mask + mask = self.mask_cache[state].to(device=scores.device, non_blocking=True) + #biased_scores = scores + mask - return biased_scores + return scores.add(mask) class JSONLogitsProcessor(RegexLogitsProcessor): @@ -134,10 +149,10 @@ class JSONLogitsProcessor(RegexLogitsProcessor): """ def __init__( - self, - schema: Union[dict, Type[BaseModel], str], - llm: "LLM", - whitespace_pattern: Optional[str] = None, + self, + schema: Union[dict, Type[BaseModel], str], + llm: "LLM", + whitespace_pattern: Optional[str] = None, ): """Compile the FSM that drives the JSON-guided generation.