Make torch and transformers imports optional
rlouf committed Apr 15, 2024
1 parent 5f98bd3 commit 9103d06
Showing 6 changed files with 108 additions and 76 deletions.
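The diff applies one pattern throughout: module-level `import torch` statements are removed, annotations reference `torch` only under `typing.TYPE_CHECKING` (so they are string literals at runtime), and the actual import is deferred to the functions that need it. A minimal sketch of the pattern, assuming nothing beyond the standard library and torch itself (the function name is illustrative, not part of the commit):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    import torch


def sample_uniform(n: int, rng: Optional["torch.Generator"] = None) -> "torch.Tensor":
    # Deferred import: the module imports fine without torch installed;
    # torch is required only when this function is actually called.
    import torch

    if rng is None:
        rng = torch.Generator()
    return torch.rand(n, generator=rng)
```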
6 changes: 5 additions & 1 deletion docs/reference/models/transformers.md
@@ -3,7 +3,11 @@

!!! Installation

    You need to install the `transformers` and `datasets` libraries to be able to use these models in Outlines.
    You need to install the `transformers`, `datasets` and `torch` libraries to be able to use these models in Outlines:

    ```bash
    pip install torch transformers datasets
    ```


Outlines provides an integration with the `torch` implementation of causal models in the [transformers][transformers] library. You can initialize the model by passing its name:
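For reference, such an initialization looks like the following (a sketch based on the `outlines.models.transformers` entry point changed in this commit; the model name is illustrative):

```python
import outlines

# Any causal model identifier from the Hugging Face Hub can be used here.
model = outlines.models.transformers("gpt2")
```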
19 changes: 11 additions & 8 deletions outlines/generate/api.py
@@ -1,12 +1,13 @@
import datetime
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union

import torch
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
    import torch

FormattedOutput = Union[
    str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
@@ -29,9 +30,9 @@ def __init__(

    def get_generated_token_ids(
        self,
        prompt_token_ids: torch.Tensor,
        token_ids: torch.Tensor,
    ) -> List[torch.Tensor]:
        prompt_token_ids: "torch.Tensor",
        token_ids: "torch.Tensor",
    ) -> List["torch.Tensor"]:
        """Get the tokens generated so far.

        Parameters
@@ -130,7 +131,7 @@ def __call__(
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional[torch.Generator] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
        """Generate the full text sequence.
@@ -157,6 +158,7 @@ def __call__(
        -------
        The generation(s), potentially cast to another type.
        """
        import torch

        if isinstance(prompts, str):
            prompts = [prompts]
@@ -247,7 +249,7 @@ def stream(
        prompts: Union[str, List[str]],
        max_tokens: Optional[int] = None,
        stop_at: Optional[Union[str, List[str]]] = None,
        rng: Optional[torch.Generator] = None,
        rng: Optional["torch.Generator"] = None,
    ) -> Iterator[Union[List[str], str, List[List[str]]]]:
        """Generate the text sequence one token at a time.
@@ -274,6 +276,7 @@ def stream(
        A string or list of strings that contain the generated text.
        """
        import torch

        if isinstance(prompts, str):
            prompts = [prompts]
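With the module-level import gone, a default `torch.Generator` is created lazily inside the call; callers who want reproducible sampling still pass their own. A caller-side sketch consistent with the `__call__` signature above (the model name and prompt are illustrative, and `outlines.generate.text` is assumed as the usual entry point returning a `SequenceGenerator`):

```python
import torch

import outlines

model = outlines.models.transformers("gpt2")  # illustrative model choice
generator = outlines.generate.text(model)

rng = torch.Generator()
rng.manual_seed(42)  # makes multinomial sampling reproducible across runs

answer = generator("The capital of France is", max_tokens=8, rng=rng)
```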
56 changes: 33 additions & 23 deletions outlines/generate/generator.py
@@ -2,9 +2,9 @@
import math
from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple

import torch

if TYPE_CHECKING:
    import torch

    from outlines.fsm.guide import Guide


@@ -14,22 +14,22 @@ class ContextLengthExceededError(Exception):

@dataclasses.dataclass(frozen=True)
class GenerationState:
    token_ids: torch.Tensor
    kv_cache: torch.Tensor
    logits: torch.Tensor
    weights: torch.Tensor
    token_ids: "torch.Tensor"
    kv_cache: "torch.Tensor"
    logits: "torch.Tensor"
    weights: "torch.Tensor"
    fsm_states: List[int]


def sequence_generator(
    model: Callable,
    sampler: Callable,
    fsms: List["Guide"],
    token_ids: torch.Tensor,
    sequence_weights: torch.Tensor,
    attention_masks: torch.Tensor,
    token_ids: "torch.Tensor",
    sequence_weights: "torch.Tensor",
    attention_masks: "torch.Tensor",
    fsm_states: List[int],
    rng: torch.Generator = torch.Generator(),
    rng: "torch.Generator",
) -> Iterator[GenerationState]:
"""Generates sequences of tokens.
@@ -62,6 +62,11 @@ def sequence_generator(
    A new sequence.
    """
    import torch

    if rng is None:
        rng = torch.Generator()

    kv_cache = None

    while True:
@@ -107,7 +112,7 @@


def get_next_fsm_states(
    fsms: List["Guide"], fsm_states: List[int], next_token_ids: torch.Tensor
    fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
) -> List[int]:
"""
@@ -129,7 +134,7 @@ def get_next_fsm_states(
    ]


def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> torch.Tensor:
def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor":
"""Get the new instructions for each sequence from the finite-state machine.
Parameters
@@ -173,10 +178,9 @@ def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool:
    return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)])


@torch.inference_mode()
def update_token_ids(
    token_ids: torch.Tensor, next_token_ids: torch.Tensor, ancestors: torch.Tensor
) -> torch.Tensor:
    token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
"""Append the sampled tokens to the running sequence of tokens.
Parameters
@@ -195,14 +199,15 @@ def update_token_ids(
    just generated.
    """
    import torch

    token_ids = torch.index_select(token_ids, 0, ancestors)
    return torch.concatenate([token_ids, next_token_ids], dim=-1)


@torch.inference_mode()
def update_attention_masks(
    attention_masks: torch.Tensor, ancestors: torch.Tensor
) -> torch.Tensor:
    attention_masks: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
"""Expand the attention masks.
Parameters
@@ -217,6 +222,8 @@ def update_attention_masks(
    The attention masks padded with 1s.
    """
    import torch

    attention_masks = torch.index_select(attention_masks, 0, ancestors)
    return torch.concatenate(
        [
@@ -229,15 +236,15 @@
    )


def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]:
def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]:
    reordered_fsms = []
    for ancestor in ancestors:
        reordered_fsms.append(fsms[ancestor].copy())

    return reordered_fsms


def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]:
def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]:
    reordered_states = []
    for ancestor in ancestors:
        reordered_states.append(fsm_states[ancestor])
@@ -246,7 +253,7 @@ def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]:


def reorder_kv_cache(
    kv_cache: Optional[Tuple], ancestors: torch.Tensor
    kv_cache: Optional[Tuple], ancestors: "torch.Tensor"
) -> Optional[Tuple]:
"""Re-order the KV-cache based on the ancestors.
@@ -256,6 +263,8 @@ def reorder_kv_cache(
    first dimension is the batch size.
    """
    import torch

    if kv_cache is None:
        return None

@@ -270,8 +279,7 @@
    return new_kv_cache


@torch.inference_mode()
def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":
"""Mask the logits.
The function iterates over a nested list where each list corresponds to the
@@ -290,6 +298,8 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
    A view of the original logits tensor where some values are masked.
    """
    import torch

    biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
    for i, ids in enumerate(allowed_token_ids):
        biased_logits[i, ids] = logits[i, ids]
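Note that the `@torch.inference_mode()` decorators are dropped rather than relocated: a decorator expression is evaluated when the module is imported, which would force `torch` to be installed just to import the module. The replacement is a local import plus an `inference_mode` context inside the function body, as `forward` does in the next file. A sketch of the two variants (illustrative function, not the repository's code):

```python
# Eager: the decorator runs at import time, so the module cannot
# even be imported without torch installed.
#
#   @torch.inference_mode()
#   def bias(logits, allowed_token_ids): ...

# Deferred: torch is touched only when the function runs.
def bias_sketch(logits, allowed_token_ids):
    import torch

    with torch.inference_mode():
        biased = torch.full_like(logits, float("-inf"))
        for i, ids in enumerate(allowed_token_ids):
            biased[i, ids] = logits[i, ids]
    return biased
```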
45 changes: 25 additions & 20 deletions outlines/models/transformers.py
@@ -1,16 +1,15 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
    import torch
    from transformers import PreTrainedModel, PreTrainedTokenizer

__all__ = ["transformers"]


KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...]


def get_llama_tokenizer_types():
@@ -77,13 +76,13 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):

    def encode(
        self, prompt: Union[str, List[str]], **kwargs
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
    ) -> Tuple["torch.LongTensor", "torch.LongTensor"]:
        kwargs["padding"] = True
        kwargs["return_tensors"] = "pt"
        output = self.tokenizer(prompt, **kwargs)
        return output["input_ids"], output["attention_mask"]

    def decode(self, token_ids: torch.LongTensor) -> List[str]:
    def decode(self, token_ids: "torch.LongTensor") -> List[str]:
        text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return text

@@ -127,13 +126,12 @@ def __init__(
        self.model = model
        self.tokenizer = TransformerTokenizer(tokenizer)

    @torch.inference_mode
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        input_ids: "torch.LongTensor",
        attention_mask: "torch.LongTensor",
        past_key_values: Optional[Tuple] = None,
    ) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
    ) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
        """Compute a forward pass through the transformer model.

        Parameters
@@ -151,28 +149,35 @@ def forward(
        The computed logits and the new cached key and value tensors.
        """
        try:
            import torch
        except ImportError:
            raise ImportError(
                "The `torch` library needs to be installed to use `transformers` models."
            )
        assert 0 < input_ids.ndim < 3

        if past_key_values:
            input_ids = input_ids[..., -1].unsqueeze(-1)

        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            past_key_values=past_key_values,
        )
        with torch.inference_mode():
            output = self.model(
                input_ids,
                attention_mask=attention_mask,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
                past_key_values=past_key_values,
            )

        return output.logits, output.past_key_values

    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        input_ids: "torch.LongTensor",
        attention_mask: "torch.LongTensor",
        past_key_values: Optional[Tuple] = None,
    ) -> torch.FloatTensor:
    ) -> "torch.FloatTensor":
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        next_token_logits = logits[..., -1, :]

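A quick way to confirm the deferral works is to check `sys.modules` in a fresh interpreter: a module written this way should not pull `torch` in until a guarded function is called. A self-contained sketch, not a test from this repository:

```python
import subprocess
import sys

snippet = """
import sys
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch

def f(x: "torch.Tensor") -> "torch.Tensor":
    import torch  # deferred to first call
    return torch.as_tensor(x)

print("before call:", "torch" in sys.modules)
f([1, 2, 3])
print("after call:", "torch" in sys.modules)
"""

# Run in a fresh interpreter so modules imported elsewhere don't skew the check.
result = subprocess.run([sys.executable, "-c", snippet], capture_output=True, text=True)
print(result.stdout)  # expected: "before call: False" then "after call: True"
```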