From 1ec8bb59a41b0ba00d86e96231947eeae31e8a61 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 13 Jun 2024 13:49:22 -0500 Subject: [PATCH] add outlines.models.mlxlm --- docs/reference/models/mlxlm.md | 32 +++ mkdocs.yml | 1 + outlines/generate/cfg.py | 8 +- outlines/generate/regex.py | 13 + outlines/generate/text.py | 7 +- outlines/models/__init__.py | 1 + outlines/models/mlxlm.py | 240 +++++++++++++++++++ outlines/processors/__init__.py | 7 + outlines/processors/base_logits_processor.py | 78 ++++++ outlines/processors/structured.py | 187 +++++++++++++++ pyproject.toml | 3 + tests/generate/conftest.py | 18 ++ tests/generate/test_generate.py | 54 +++++ 13 files changed, 645 insertions(+), 4 deletions(-) create mode 100644 docs/reference/models/mlxlm.md create mode 100644 outlines/models/mlxlm.py create mode 100644 outlines/processors/__init__.py create mode 100644 outlines/processors/base_logits_processor.py create mode 100644 outlines/processors/structured.py create mode 100644 tests/generate/test_generate.py diff --git a/docs/reference/models/mlxlm.md b/docs/reference/models/mlxlm.md new file mode 100644 index 000000000..539e03851 --- /dev/null +++ b/docs/reference/models/mlxlm.md @@ -0,0 +1,32 @@ +# mlx-lm + +Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx-examples/tree/main/llms), allowing models to be run quickly on Apple Silicon via the [mlx](https://ml-explore.github.io/mlx/build/html/index.html) library. + +## Installation + +In addition to `outlines`, you must install `mlx-lm` and `mlx` libraries. You must use a device which [supports Metal](https://support.apple.com/en-us/102894). + +## Using `models.mlxlm` + +```python +from outlines import models + +model = models.mlxlm("mlx-community/mlx-community/Meta-Llama-3-8B-Instruct-8bit") +``` + +With the loaded model, you can generate text or perform structured generation, e.g. + +```python3 +from outlines import models, generate + +model = models.mlxlm("mlx-community/Meta-Llama-3-8B-Instruct-8bit") + +phone_number_pattern = "\\+?[1-9][0-9]{7,14}" +generator = generate.regex(model, phone_number_pattern) + +model_output = generator("What's Jennys Number?\n") +print(model_output) +# '8675309' +``` + +For more examples, see the [cookbook](cookbook/index.md). diff --git a/mkdocs.yml b/mkdocs.yml index 01e8506ab..e967d183f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -124,6 +124,7 @@ nav: - vLLM: reference/models/vllm.md - Llama.cpp: reference/models/llamacpp.md - Transformers: reference/models/transformers.md + - MLX: reference/models/mlxlm.md - ExllamaV2: reference/models/exllamav2.md - Mamba: reference/models/mamba.md - OpenAI: reference/models/openai.md diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index 0a6698b08..e473c26a6 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -4,6 +4,7 @@ from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.models.llamacpp import LlamaCpp +from outlines.models.mlxlm import MLXLM from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @@ -33,14 +34,15 @@ def cfg(model, cfg_str: str, sampler: Sampler = multinomial()) -> SequenceGenera return generator +@cfg.register(MLXLM) @cfg.register(VLLM) -def cfg_vllm( - model: VLLM, +def cfg_unimplemented( + model, cfg_str: str, sampler: Sampler = multinomial(), ): raise NotImplementedError( - "The CFG Logits processor is not available for the vLLM integration." + f"The CFG Logits processor is not available for {type(model)}." ) diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index ceea5d994..6b6656fe9 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -4,6 +4,7 @@ from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.models.llamacpp import LlamaCpp +from outlines.models.mlxlm import MLXLM from outlines.models.vllm import VLLM from outlines.samplers import Sampler, multinomial @@ -37,6 +38,18 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()): return generator +@regex.register(MLXLM) +def regex_mlxlm( + model: MLXLM, + regex_str: str, + sampler: Sampler = multinomial(), +): + from outlines.processors import RegexLogitsProcessor + + logits_processor = RegexLogitsProcessor(regex_str, tokenizer=model.tokenizer) + return SequenceGeneratorAdapter(model, logits_processor, sampler) + + @regex.register(LlamaCpp) def regex_llamacpp( model: LlamaCpp, diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 35031348d..081ba0920 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -2,7 +2,7 @@ from outlines.fsm.guide import StopAtEOSGuide from outlines.generate.api import SequenceGenerator, SequenceGeneratorAdapter -from outlines.models import VLLM, LlamaCpp, OpenAI +from outlines.models import MLXLM, VLLM, LlamaCpp, OpenAI from outlines.samplers import Sampler, multinomial @@ -36,6 +36,11 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGenerator: return generator +@text.register(MLXLM) +def text_mlxlm(model: MLXLM, sampler: Sampler = multinomial()): + return SequenceGeneratorAdapter(model, None, sampler) + + @text.register(VLLM) def text_vllm(model: VLLM, sampler: Sampler = multinomial()): return SequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 3676e6ccc..fb18824b3 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -10,6 +10,7 @@ from .exllamav2 import ExLlamaV2Model, exl2 from .llamacpp import LlamaCpp, llamacpp from .mamba import Mamba, mamba +from .mlxlm import MLXLM, mlxlm from .openai import OpenAI, azure_openai, openai from .transformers import Transformers, transformers from .vllm import VLLM, vllm diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py new file mode 100644 index 000000000..f561f269d --- /dev/null +++ b/outlines/models/mlxlm.py @@ -0,0 +1,240 @@ +import dataclasses +from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union + +from .transformers import TransformerTokenizer + +if TYPE_CHECKING: + import mlx.core as mx + import mlx.nn as nn + from transformers import PreTrainedTokenizer + + from outlines.generate.api import GenerationParameters, SamplingParameters + from outlines.processors import BaseLogitsProcessor + + +class MLXLM: + """ + Represents an `mlx_lm` model + """ + + def __init__( + self, + model: "nn.Module", + tokenizer: "PreTrainedTokenizer", + ): + self.model = model + self.mlx_tokenizer = tokenizer # returns mlx tensors, used for encode() + self.tokenizer = TransformerTokenizer( + tokenizer._tokenizer + ) # _tokenizer is HF Tokenizer + + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: "GenerationParameters", + logits_processor, + sampling_parameters: "SamplingParameters", + ) -> str: + streamer = self.stream( + prompts, generation_parameters, logits_processor, sampling_parameters + ) + return "".join(list(streamer)) + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: "GenerationParameters", + logits_processor, + sampling_parameters: "SamplingParameters", + ) -> Iterator[str]: + """Generate text using `mlx_lm`. + + Arguments + --------- + prompts + A prompt or list of prompts. + generation_parameters + An instance of `GenerationParameters` that contains the prompt, + the maximum number of tokens, stop sequences and seed. All the + arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + logits_processor + The logits processor to use when generating text. + sampling_parameters + An instance of `SamplingParameters`, a dataclass that contains + the name of the sampler to use and related parameters as available + in Outlines. + Returns + ------- + The generated text. + """ + import mlx.core as mx + + max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( + sampling_parameters + ) + if max_tokens is None: + max_tokens = int(1e9) + + if not isinstance(prompts, str): + raise NotImplementedError( + "The `mlx-lm` library does not support batch inference." + ) + if sampler == "beam_search": + raise NotImplementedError( + "The `mlx-lm` library does not support Beam Search." + ) + if num_samples != 1: + raise NotImplementedError( + "The `mlx-lm` library does not allow to take several samples." + ) + if top_k is not None: + raise NotImplementedError("The `mlx-lm` library does not support top_k.") + if seed is not None: + raise NotImplementedError("The `mlx-lm` library does not support seed.") + if stop_at is not None: + raise NotImplementedError("The `mlx-lm` library does not support stop_at.") + + generate_kwargs = { + "temp": temperature, + "top_p": top_p, + "sampler": sampler, + "logits_processor": logits_processor, + } + + # Adapted from + # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267 + prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts)) + + for (token, prob), n in zip( + self.generate_step(prompt_tokens, **generate_kwargs), + range(max_tokens), + ): + if token == self.tokenizer.eos_token_id: + break + yield self.tokenizer.decode([token])[0] + + def generate_step( + self, + prompt: "mx.array", + temp: Optional[float], + top_p: Optional[float], + sampler: str, + logits_processor: "BaseLogitsProcessor", + ) -> Generator[Tuple[int, float], None, None]: + """ + Adapted from + https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129 + + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + sampler (str): The sampler string defined by SequenceGeneratorAdapter + logits_processor (BaseLogitsProcessor): Augment logits before sampling. + """ + import mlx.core as mx + import mlx_lm + + temperature: float = temp or 1.0 + + def sample(logits: "mx.array") -> Tuple["mx.array", float]: + softmax_logits = mx.softmax(logits) + + if temperature == 0.0 or sampler == "greedy": + token = mx.argmax(logits, axis=-1) + elif sampler == "multinomial": + if top_p is not None and top_p > 0 and top_p < 1.0: + token = mlx_lm.sample_utils.top_p_sampling( + logits, top_p, temperature + ) + else: + token = mx.random.categorical(logits * (1 / temperature)) + else: + raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`") + + prob = softmax_logits[0, token] + return token, prob + + kv_heads = ( + [self.model.n_kv_heads] * len(self.model.layers) + if isinstance(self.model.n_kv_heads, int) + else self.model.n_kv_heads + ) + cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads] + + # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model() + unprocessed_input_ids = prompt + generated_ids: List[int] = [] + + while True: + logits = self.model(unprocessed_input_ids[None], cache=cache) + logits = logits[:, -1, :] + + if logits_processor is not None: + # convert to logits_processor 1d expectation, apply, then convert back + logits_1d = logits.reshape(-1) + logits_1d = logits_processor(generated_ids, logits_1d) + logits = logits_1d.reshape(1, -1) + + new_token_single, prob = sample(logits) + new_token = new_token_single.item() + yield new_token, prob + + generated_ids.append(new_token) + unprocessed_input_ids = new_token_single + + +def mlxlm( + model_name: str, + tokenizer_config: dict = {}, + model_config: dict = {}, + adapter_path: Optional[str] = None, + lazy: bool = False, +): + """Instantiate a model from the `mlx_lm` library and its tokenizer. + + Signature adapted from + https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L422 + + Parameters + ---------- + Args: + path_or_hf_repo (Path): The path or the huggingface repository to load the model from. + tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer. + Defaults to an empty dictionary. + model_config(dict, optional): Configuration parameters specifically for the model. + Defaults to an empty dictionary. + adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers + to the model. Default: ``None``. + lazy (bool): If False eval the model parameters to make sure they are + loaded in memory before returning, otherwise they will be loaded + when needed. Default: ``False`` + + Returns + ------- + A `MLXLM` model instance. + + """ + try: + import mlx.core as mx + import mlx_lm + except ImportError: + raise ImportError( + "The `mlx_lm` library needs to be installed in order to use `mlx_lm` models." + ) + if not mx.metal.is_available(): + raise RuntimeError("You cannot use `mlx_lm` without Apple Silicon (Metal)") + + model, tokenizer = mlx_lm.load( + model_name, + tokenizer_config=tokenizer_config, + model_config=model_config, + adapter_path=adapter_path, + lazy=lazy, + ) + return MLXLM(model, tokenizer) diff --git a/outlines/processors/__init__.py b/outlines/processors/__init__.py new file mode 100644 index 000000000..5c6a697ed --- /dev/null +++ b/outlines/processors/__init__.py @@ -0,0 +1,7 @@ +from .structured import ( + BaseLogitsProcessor, + CFGLogitsProcessor, + FSMLogitsProcessor, + JSONLogitsProcessor, + RegexLogitsProcessor, +) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py new file mode 100644 index 000000000..dabfd91b0 --- /dev/null +++ b/outlines/processors/base_logits_processor.py @@ -0,0 +1,78 @@ +from abc import abstractmethod +from typing import List, Protocol, Union + +import numpy as np +import torch +from numpy.typing import NDArray + + +def is_mlx_array(logits): + try: + import mlx.core as mx + except ImportError: + return False + return isinstance(logits, mx.array) + + +class BaseLogitsProcessor(Protocol): + """ + Base class for logits processors which normalizes types of logits: + - ndarray (used by llama-cpp-python), converted to torch.Tensor + - torch.Tensor (used by everything else) + + Normalization of types and conversion to torch.Tensor + doesn't move memory, it just casts the type. + + Normalizing the types allows all logits processors inheriting from this class + to implement a single method for all the business logit: `process_logits()` + """ + + @abstractmethod + def process_logits( + self, input_ids: List[int], logits: torch.Tensor + ) -> torch.Tensor: + ... + + def __call__( + self, + input_ids: Union[NDArray[np.int64], List[int], torch.Tensor], + logits: Union[NDArray[np.float32], torch.Tensor], + ) -> Union[NDArray[np.int64], torch.Tensor]: + """ + Apply logits processor + Unify type + - convert input_ids: either ndarray, List[int], or Tensor -> List[int] + - convert logits: either ndarray, mlx array, Tensor -> Tensor + Call process_logits() to perform business logic + """ + with torch.no_grad(): + if not isinstance(input_ids, list): + input_ids = input_ids.tolist() + + if isinstance(logits, np.ndarray): + # Unify type, convert numpy array to Tensor + # from_numpy and .numpy() don't copy the data, it uses the same memory address + torch_logits = torch.from_numpy(logits) + processed_torch_logits = self.process_logits(input_ids, torch_logits) + return processed_torch_logits.detach().numpy() + + elif isinstance(logits, torch.Tensor): + return self.process_logits(input_ids, logits) + + elif is_mlx_array(logits): + # mlx -> torch -> mlx conversion docs: + # https://ml-explore.github.io/mlx/build/html/usage/numpy.html + import mlx.core as mx + + torch_logits = torch.from_dlpack(logits) + processed_torch_logits = self.process_logits(input_ids, torch_logits) + + # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch + logits_float32_numpy = processed_torch_logits.float().numpy() + return mx.array(logits_float32_numpy) + + else: + raise TypeError( + "LogitsProcessor must be called with either np.NDArray" + ", torch.Tensor, or mlx.core.array typed logits" + ) diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py new file mode 100644 index 000000000..b8ef5b2da --- /dev/null +++ b/outlines/processors/structured.py @@ -0,0 +1,187 @@ +""" + _______________________________ +/ Don't want to self-host? \ +\\ Try .json at http://dottxt.co / + ------------------------------- + \\ ^__^ + \\ (oo)\\_______ + (__)\\ )\\/\ + ||----w | + || || + +Copyright 2024- the Outlines developers + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import math +from typing import TYPE_CHECKING, List, Optional, Type, Union + +import numpy as np +import torch +from numpy.typing import NDArray +from pydantic import BaseModel + +from outlines.fsm.guide import CFGGuide, Guide, RegexGuide +from outlines.fsm.json_schema import build_regex_from_schema +from outlines.integrations.utils import convert_json_schema_to_str + +from .base_logits_processor import BaseLogitsProcessor + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +class FSMLogitsProcessor(BaseLogitsProcessor): + """Bias generation using a finite state machine. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, tokenizer: "Tokenizer", fsm: Guide): + """A FSM-based logits processor. + + Parameters + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + self.tokenizer = tokenizer + self._fsm_state = 0 + self.fsm: Guide = fsm + self._is_first_token = True + + def process_logits( + self, input_ids: List[int], logits: torch.Tensor + ) -> NDArray[np.float32]: + """Use the FSM to bias the logits before sampling the next token. + + Parameters + ---------- + input_ids + The input token ids. + logits + The logits. + + Returns + ------- + torch.Tensor + The biased logits. + """ + if self._is_first_token: + self._is_first_token = False + else: + last_token = input_ids[-1] + self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token) + + allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens + allowed_tokens = torch.tensor(allowed_tokens, device=logits.device) + + mask = torch.full_like(logits, -math.inf) + mask[allowed_tokens] = logits[allowed_tokens] + return mask + + def copy(self) -> "FSMLogitsProcessor": + """Return a copy of the logits processor.""" + return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy()) + + +class RegexLogitsProcessor(FSMLogitsProcessor): + """Bias generation based on a regular expression. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, regex_string: str, tokenizer: "Tokenizer"): + """Compile the FSM that drives the regex-guided generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + An Outlines tokenizer + """ + fsm = RegexGuide(regex_string, tokenizer) + super().__init__(tokenizer=tokenizer, fsm=fsm) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + """Bias generation based on a JSON schema. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__( + self, + schema: Union[dict, Type[BaseModel], str], + tokenizer: "Tokenizer", + whitespace_pattern: Optional[str] = None, + ): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate. + tokenizer + The tokenizer used to convert tokens to ids. + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string + literals). For example, to allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + schema_str = convert_json_schema_to_str(json_schema=schema) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string=regex_string, tokenizer=tokenizer) + + +class CFGLogitsProcessor(FSMLogitsProcessor): + """Bias generation based on a context-free grammar. + + Attributes + ---------- + tokenizer + The tokenizer used to convert tokens to ids. + fsm + The finite state machine which is used to bias the logits. + """ + + def __init__(self, cfg_str: str, tokenizer: "Tokenizer"): + """Compile the FSM that drives the CFG-guided generation. + + Parameters + ---------- + cfg_str + A string that represents a grammar + tokenizer + The tokenizer used to convert tokens to ids. + """ + cfg_automata = CFGGuide(cfg_string=cfg_str, tokenizer=tokenizer) + super().__init__(tokenizer=tokenizer, fsm=cfg_automata) diff --git a/pyproject.toml b/pyproject.toml index c6f72d3e3..d7d0db8b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ test = [ "beartype<0.16.0", "responses", "llama-cpp-python", + "mlx-lm", "huggingface_hub", "openai>=1.0.0", "vllm", @@ -110,6 +111,8 @@ module = [ "jsonschema.*", "openai.*", "mamba_ssm.*", + "mlx_lm.*", + "mlx.*", "nest_asyncio", "numpy.*", "cloudpickle.*", diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py index ef7e40eed..5b3e6f79c 100644 --- a/tests/generate/conftest.py +++ b/tests/generate/conftest.py @@ -3,6 +3,24 @@ import pytest +def pytest_collection_modifyitems(config, items): + """If mlxlm and Metal aren't available, skip mlxlm tests""" + try: + import mlx.core as mx + import mlx_lm # noqa: F401 + + assert mx.metal.is_available() + except (ImportError, AssertionError): + skip_marker = pytest.mark.skip( + reason="Skipping test because mlx-lm or Metal are not available" + ) + for item in items: + if "model_fixture" in item.fixturenames: + model_param = item.callspec.params.get("model_fixture", None) + if model_param == "model_mlxlm": + item.add_marker(skip_marker) + + @pytest.fixture def temp_cache_dir(): import os diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py new file mode 100644 index 000000000..1f1a3aea2 --- /dev/null +++ b/tests/generate/test_generate.py @@ -0,0 +1,54 @@ +import re + +import pytest + +import outlines.generate as generate +import outlines.models as models + + +@pytest.fixture(scope="session") +def model_llamacpp(tmp_path_factory): + return models.llamacpp( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) + + +@pytest.fixture(scope="session") +def model_mlxlm(tmp_path_factory): + return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit") + + +@pytest.fixture(scope="session") +def model_transformers(tmp_path_factory): + return models.transformers("Locutusque/TinyMistral-248M-v2-Instruct", device="cpu") + + +@pytest.mark.parametrize( + "model_fixture", + ("model_llamacpp", "model_mlxlm", "model_transformers"), +) +def test_generate_text(request, model_fixture): + model = request.getfixturevalue(model_fixture) + generator = generate.text(model) + res = generator("test", max_tokens=10) + assert isinstance(res, str) + + +@pytest.mark.parametrize( + "model_fixture", + ("model_llamacpp", "model_mlxlm", "model_transformers"), +) +@pytest.mark.parametrize( + "pattern", + ( + "[0-9]", + "abc*", + "\\+?[1-9][0-9]{7,14}", + ), +) +def test_generate_regex(request, model_fixture, pattern): + model = request.getfixturevalue(model_fixture) + generator = generate.regex(model, pattern) + res = generator("foobarbaz", max_tokens=20) + assert re.match(pattern, res) is not None, res