From 9103d06dee9321a631e4fa01ad3526372fea0b62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Sun, 14 Apr 2024 21:06:16 +0200
Subject: [PATCH] Make `torch` and `transformers` imports optional

---
 docs/reference/models/transformers.md |  6 ++-
 outlines/generate/api.py              | 19 +++++----
 outlines/generate/generator.py        | 56 ++++++++++++++++-----------
 outlines/models/transformers.py       | 45 +++++++++++----------
 outlines/samplers.py                  | 51 ++++++++++++++----------
 pyproject.toml                        |  7 ++--
 6 files changed, 108 insertions(+), 76 deletions(-)

diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 2d9880a6b..286df4367 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -3,7 +3,11 @@
 
 !!! Installation
 
-    You need to install the `transformer` and `datasets` libraries to be able to use these models in Outlines.
+    You need to install the `transformers`, `datasets` and `torch` libraries to be able to use these models in Outlines:
+
+    ```bash
+    pip install torch transformers datasets
+    ```
 
 Outlines provides an integration with the `torch` implementation of causal models in the [transformers][transformers] library. You can initialize the model by passing its name:
 
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 4fa0a3e79..3f4f182d2 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -1,12 +1,13 @@
 import datetime
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Union
-
-import torch
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
+if TYPE_CHECKING:
+    import torch
+
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
@@ -29,9 +30,9 @@ def __init__(
 
     def get_generated_token_ids(
         self,
-        prompt_token_ids: torch.Tensor,
-        token_ids: torch.Tensor,
-    ) -> List[torch.Tensor]:
+        prompt_token_ids: "torch.Tensor",
+        token_ids: "torch.Tensor",
+    ) -> List["torch.Tensor"]:
         """Get the tokens generated so far.
 
         Parameters
@@ -130,7 +131,7 @@ def __call__(
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional[torch.Generator] = None,
+        rng: Optional["torch.Generator"] = None,
     ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
         """Generate the full text sequence.
 
@@ -157,6 +158,7 @@ def __call__(
         -------
         The generation(s), potentially cast to another type.
         """
+        import torch
 
         if isinstance(prompts, str):
             prompts = [prompts]
@@ -247,7 +249,7 @@ def stream(
         prompts: Union[str, List[str]],
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
-        rng: Optional[torch.Generator] = None,
+        rng: Optional["torch.Generator"] = None,
     ) -> Iterator[Union[List[str], str, List[List[str]]]]:
         """Generate the text sequence one token at a time.
 
@@ -274,6 +276,7 @@ def stream(
         A string or list of strings that contain the generated text.
""" + import torch if isinstance(prompts, str): prompts = [prompts] diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py index ad8ae8537..ca5fa395f 100644 --- a/outlines/generate/generator.py +++ b/outlines/generate/generator.py @@ -2,9 +2,9 @@ import math from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple -import torch - if TYPE_CHECKING: + import torch + from outlines.fsm.guide import Guide @@ -14,10 +14,10 @@ class ContextLengthExceededError(Exception): @dataclasses.dataclass(frozen=True) class GenerationState: - token_ids: torch.Tensor - kv_cache: torch.Tensor - logits: torch.Tensor - weights: torch.Tensor + token_ids: "torch.Tensor" + kv_cache: "torch.Tensor" + logits: "torch.Tensor" + weights: "torch.Tensor" fsm_states: List[int] @@ -25,11 +25,11 @@ def sequence_generator( model: Callable, sampler: Callable, fsms: List["Guide"], - token_ids: torch.Tensor, - sequence_weights: torch.Tensor, - attention_masks: torch.Tensor, + token_ids: "torch.Tensor", + sequence_weights: "torch.Tensor", + attention_masks: "torch.Tensor", fsm_states: List[int], - rng: torch.Generator = torch.Generator(), + rng: "torch.Generator", ) -> Iterator[GenerationState]: """Generates sequences of tokens. @@ -62,6 +62,11 @@ def sequence_generator( A new sequence. """ + import torch + + if rng is None: + rng = torch.Generator() + kv_cache = None while True: @@ -107,7 +112,7 @@ def sequence_generator( def get_next_fsm_states( - fsms: List["Guide"], fsm_states: List[int], next_token_ids: torch.Tensor + fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor" ) -> List[int]: """ @@ -129,7 +134,7 @@ def get_next_fsm_states( ] -def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> torch.Tensor: +def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor": """Get the new instructions for each sequence from the finite-state machine. Parameters @@ -173,10 +178,9 @@ def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool: return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)]) -@torch.inference_mode() def update_token_ids( - token_ids: torch.Tensor, next_token_ids: torch.Tensor, ancestors: torch.Tensor -) -> torch.Tensor: + token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": """Append the sampled tokens to the running sequence of tokens. Parameters @@ -195,14 +199,15 @@ def update_token_ids( just generated. """ + import torch + token_ids = torch.index_select(token_ids, 0, ancestors) return torch.concatenate([token_ids, next_token_ids], dim=-1) -@torch.inference_mode() def update_attention_masks( - attention_masks: torch.Tensor, ancestors: torch.Tensor -) -> torch.Tensor: + attention_masks: "torch.Tensor", ancestors: "torch.Tensor" +) -> "torch.Tensor": """Expand the attention masks. Parameters @@ -217,6 +222,8 @@ def update_attention_masks( The attention masks padded with 1s. 
""" + import torch + attention_masks = torch.index_select(attention_masks, 0, ancestors) return torch.concatenate( [ @@ -229,7 +236,7 @@ def update_attention_masks( ) -def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]: +def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]: reordered_fsms = [] for ancestor in ancestors: reordered_fsms.append(fsms[ancestor].copy()) @@ -237,7 +244,7 @@ def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]: return reordered_fsms -def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]: +def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]: reordered_states = [] for ancestor in ancestors: reordered_states.append(fsm_states[ancestor]) @@ -246,7 +253,7 @@ def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[i def reorder_kv_cache( - kv_cache: Optional[Tuple], ancestors: torch.Tensor + kv_cache: Optional[Tuple], ancestors: "torch.Tensor" ) -> Optional[Tuple]: """Re-order the KV-cache based on the ancestors. @@ -256,6 +263,8 @@ def reorder_kv_cache( first dimension is the batch size. """ + import torch + if kv_cache is None: return None @@ -270,8 +279,7 @@ def reorder_kv_cache( return new_kv_cache -@torch.inference_mode() -def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor: +def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor": """Mask the logits. The function iterates over a nested list where each list corresponds to the @@ -290,6 +298,8 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor: A view of the original logits tensor where some values are masked. """ + import torch + biased_logits = torch.full_like(logits, -math.inf, device=logits.device) for i, ids in enumerate(allowed_token_ids): biased_logits[i, ids] = logits[i, ids] diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 1b29ee2f4..3bc59412e 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -1,16 +1,15 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import torch - from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: + import torch from transformers import PreTrainedModel, PreTrainedTokenizer __all__ = ["transformers"] -KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...] +KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...] 
 def get_llama_tokenizer_types():
@@ -77,13 +76,13 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):
 
     def encode(
         self, prompt: Union[str, List[str]], **kwargs
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+    ) -> Tuple["torch.LongTensor", "torch.LongTensor"]:
         kwargs["padding"] = True
         kwargs["return_tensors"] = "pt"
         output = self.tokenizer(prompt, **kwargs)
         return output["input_ids"], output["attention_mask"]
 
-    def decode(self, token_ids: torch.LongTensor) -> List[str]:
+    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
         text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
         return text
 
@@ -127,13 +126,12 @@ def __init__(
         self.model = model
         self.tokenizer = TransformerTokenizer(tokenizer)
 
-    @torch.inference_mode
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
+        input_ids: "torch.LongTensor",
+        attention_mask: "torch.LongTensor",
         past_key_values: Optional[Tuple] = None,
-    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
+    ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
         """Compute a forward pass through the transformer model.
 
         Parameters
@@ -151,28 +149,35 @@ def forward(
         The computed logits and the new cached key and value tensors.
 
         """
+        try:
+            import torch
+        except ImportError:
+            raise ImportError(
+                "The `torch` library needs to be installed to use `transformers` models."
+            )
         assert 0 < input_ids.ndim < 3
 
         if past_key_values:
             input_ids = input_ids[..., -1].unsqueeze(-1)
 
-        output = self.model(
-            input_ids,
-            attention_mask=attention_mask,
-            return_dict=True,
-            output_attentions=False,
-            output_hidden_states=False,
-            past_key_values=past_key_values,
-        )
+        with torch.inference_mode():
+            output = self.model(
+                input_ids,
+                attention_mask=attention_mask,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                past_key_values=past_key_values,
+            )
 
         return output.logits, output.past_key_values
 
     def __call__(
         self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.LongTensor,
+        input_ids: "torch.LongTensor",
+        attention_mask: "torch.LongTensor",
         past_key_values: Optional[Tuple] = None,
-    ) -> torch.FloatTensor:
+    ) -> "torch.FloatTensor":
         logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
         next_token_logits = logits[..., -1, :]
 
diff --git a/outlines/samplers.py b/outlines/samplers.py
index cdfcbe2ed..8b64ed768 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -1,7 +1,8 @@
 import math
-from typing import Callable, Optional, Protocol, Tuple
+from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple
 
-import torch
+if TYPE_CHECKING:
+    import torch
 
 
 class Sampler(Protocol):
@@ -9,10 +10,10 @@ class Sampler(Protocol):
 
     def __call__(
         self,
-        next_token_logits: torch.DoubleTensor,
-        sequence_weights: torch.DoubleTensor,
-        rng: torch.Generator,
-    ) -> torch.DoubleTensor:
+        next_token_logits: "torch.DoubleTensor",
+        sequence_weights: "torch.DoubleTensor",
+        rng: "torch.Generator",
+    ) -> "torch.DoubleTensor":
         ...
 
 
@@ -38,10 +39,10 @@ def __init__(self):
 
     def __call__(
         self,
-        next_token_logits: torch.DoubleTensor,
-        sequence_weights: torch.DoubleTensor,
+        next_token_logits: "torch.DoubleTensor",
+        sequence_weights: "torch.DoubleTensor",
         _,
-    ) -> torch.DoubleTensor:
+    ) -> "torch.DoubleTensor":
         """Call the greedy sampler.
 
         Parameters
@@ -63,6 +64,8 @@ def __call__(
         cumulative weights of each sequence of shape ``(n_seqs,)``.
""" + import torch + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True) @@ -116,10 +119,10 @@ def __init__( def __call__( self, - next_token_logits: torch.DoubleTensor, - sequence_weights: torch.DoubleTensor, - rng: torch.Generator, - ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]: + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", + rng: "torch.Generator", + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: """Call the multinomial sampler. Parameters @@ -141,6 +144,8 @@ def __call__( cumulative weights of each sequence of shape ``(n_seqs,)``. """ + import torch + altered_next_token_logits = next_token_logits for logit_processor in self.logits_processors: altered_next_token_logits = logit_processor(next_token_logits) @@ -160,7 +165,7 @@ def __call__( multinomial = MultinomialSampler -def keep_top_k_logits(k: int) -> Callable[[torch.Tensor], torch.Tensor]: +def keep_top_k_logits(k: int) -> Callable[["torch.Tensor"], "torch.Tensor"]: """Build a function that masks logits values smaller than the top `k` ones. Parameters @@ -169,6 +174,8 @@ def keep_top_k_logits(k: int) -> Callable[[torch.Tensor], torch.Tensor]: The ranking below which logit values are replaced by `-math.inf`. """ + import torch + if not isinstance(k, int) or k < 1: raise ValueError(f"`k` must be a strictly positive integers, got {k} instead.") @@ -180,7 +187,7 @@ def logits_processor(logits: torch.Tensor) -> torch.Tensor: return logits_processor -def keep_top_p_logits(p: float) -> Callable[[torch.Tensor], torch.Tensor]: +def keep_top_p_logits(p: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: """Build a function that masks the lowest probability tokens whose cumulative probability is below a certain threshold. @@ -192,6 +199,8 @@ def keep_top_p_logits(p: float) -> Callable[[torch.Tensor], torch.Tensor]: others. Its value must be between 0 (excluded) and 1 (included). """ + import torch + if p <= 0.0 or p > 1.0: raise ValueError( f"`p` must be a floating point number between 0 (excluded) and 1 (included), got {p} instead." @@ -210,7 +219,7 @@ def logits_processor(logits: torch.Tensor) -> torch.Tensor: return logits_processor -def rescale_logits(temperature: float) -> Callable[[torch.Tensor], torch.Tensor]: +def rescale_logits(temperature: float) -> Callable[["torch.Tensor"], "torch.Tensor"]: """Build a function that rescales the token probabilities exponentially. Parameters @@ -229,7 +238,7 @@ def rescale_logits(temperature: float) -> Callable[[torch.Tensor], torch.Tensor] "Please use the greedy sampler instead of setting the temperature to 0." ) - def logits_processor(logits: torch.Tensor) -> torch.Tensor: + def logits_processor(logits: "torch.Tensor") -> "torch.Tensor": return logits / temperature return logits_processor @@ -250,10 +259,10 @@ def __init__(self, beams: int = 1): def __call__( self, - next_token_logits: torch.DoubleTensor, - sequence_weights: torch.DoubleTensor, + next_token_logits: "torch.DoubleTensor", + sequence_weights: "torch.DoubleTensor", _, - ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]: + ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]: """Call the beam search sampler. Parameters @@ -275,6 +284,8 @@ def __call__( cumulative weights of each sequence of shape ``(n_seqs,)``. 
""" + import torch + logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) weights = logprobs + sequence_weights.unsqueeze(1).expand_as(next_token_logits) diff --git a/pyproject.toml b/pyproject.toml index 3137f281f..5a0cc6986 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,12 +31,10 @@ dependencies = [ "cloudpickle", "diskcache", "pydantic>=2.0", - "torch>=2.1.0", "numba", "referencing", "jsonschema", "requests", - "transformers", ] dynamic = ["version"] @@ -47,7 +45,6 @@ test = [ "pytest-benchmark", "pytest-cov", "pytest-mock", - "transformers", "coverage[toml]>=5.1", "diff-cover", "accelerate", @@ -57,7 +54,9 @@ test = [ "llama-cpp-python", "huggingface_hub", "openai>=1.0.0", - "vllm" + "vllm", + "torch", + "transformers", ] serve = [ "vllm>=0.3.0",