Make torch and transformers imports optional #815

Merged · 1 commit · Apr 15, 2024
6 changes: 5 additions & 1 deletion docs/reference/models/transformers.md
@@ -3,7 +3,11 @@

!!! Installation

You need to install the `transformer` and `datasets` libraries to be able to use these models in Outlines.
You need to install the `transformers`, `datasets` and `torch` libraries to be able to use these models in Outlines:

```bash
pip install torch transformers datasets
```


Outlines provides an integration with the `torch` implementation of causal models in the [transformers][transformers] library. You can initialize the model by passing its name:
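
The initialization example itself is truncated in this view. A minimal sketch of the documented pattern, using the `transformers` loader this PR touches (the model name is illustrative, not from this diff):

```python
import outlines

# Any Hugging Face causal model name works here; this one is an assumption.
model = outlines.models.transformers("mistralai/Mistral-7B-v0.1")
```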
19 changes: 11 additions & 8 deletions outlines/generate/api.py
@@ -1,12 +1,13 @@
import datetime
from dataclasses import dataclass
from typing import Iterator, List, Optional, Union

import torch
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
import torch

FormattedOutput = Union[
str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
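
This hunk shows the pattern applied throughout the PR: `torch` is imported only under `typing.TYPE_CHECKING`, so static type checkers still resolve the annotations, while the annotations themselves become string forward references and the runtime import moves into the bodies of the functions that need it. A minimal self-contained sketch of the idiom (not code from this PR):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch  # seen by type checkers only; never executed at runtime


def double(t: "torch.Tensor") -> "torch.Tensor":
    import torch  # deferred: torch is required only when this function runs

    return torch.mul(t, 2)
```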
@@ -29,9 +30,9 @@ def get_generated_token_ids(

def get_generated_token_ids(
self,
prompt_token_ids: torch.Tensor,
token_ids: torch.Tensor,
) -> List[torch.Tensor]:
prompt_token_ids: "torch.Tensor",
token_ids: "torch.Tensor",
) -> List["torch.Tensor"]:
"""Get the tokens generated so far.

Parameters
@@ -130,7 +131,7 @@ def __call__(
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional[torch.Generator] = None,
rng: Optional["torch.Generator"] = None,
) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
"""Generate the full text sequence.

@@ -157,6 +158,7 @@
-------
The generation(s), potentially cast to another type.
"""
import torch

if isinstance(prompts, str):
prompts = [prompts]
@@ -247,7 +249,7 @@ def stream(
prompts: Union[str, List[str]],
max_tokens: Optional[int] = None,
stop_at: Optional[Union[str, List[str]]] = None,
rng: Optional[torch.Generator] = None,
rng: Optional["torch.Generator"] = None,
) -> Iterator[Union[List[str], str, List[List[str]]]]:
"""Generate the text sequence one token at a time.

@@ -274,6 +276,7 @@
A string or list of strings that contain the generated text.

"""
import torch

if isinstance(prompts, str):
prompts = [prompts]
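A function-local `import torch` looks expensive, but Python caches modules in `sys.modules`, so every import after the first is a dictionary lookup. A quick sketch to convince yourself:

```python
import timeit


def f():
    import math  # resolved from sys.modules after the first call

    return math.pi


# Roughly microsecond-scale per call; the import cost is paid only once.
print(timeit.timeit(f, number=1_000_000))
```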
56 changes: 33 additions & 23 deletions outlines/generate/generator.py
@@ -2,9 +2,9 @@
import math
from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple

import torch

if TYPE_CHECKING:
import torch

from outlines.fsm.guide import Guide


@@ -14,22 +14,22 @@ class ContextLengthExceededError(Exception):

@dataclasses.dataclass(frozen=True)
class GenerationState:
token_ids: torch.Tensor
kv_cache: torch.Tensor
logits: torch.Tensor
weights: torch.Tensor
token_ids: "torch.Tensor"
kv_cache: "torch.Tensor"
logits: "torch.Tensor"
weights: "torch.Tensor"
fsm_states: List[int]


def sequence_generator(
model: Callable,
sampler: Callable,
fsms: List["Guide"],
token_ids: torch.Tensor,
sequence_weights: torch.Tensor,
attention_masks: torch.Tensor,
token_ids: "torch.Tensor",
sequence_weights: "torch.Tensor",
attention_masks: "torch.Tensor",
fsm_states: List[int],
rng: torch.Generator = torch.Generator(),
rng: "torch.Generator",
) -> Iterator[GenerationState]:
"""Generates sequences of tokens.

@@ -62,6 +62,11 @@ def sequence_generator(
A new sequence.

"""
import torch

if rng is None:
rng = torch.Generator()

kv_cache = None

while True:
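
Worth noting about the signature change above: the removed default `rng: torch.Generator = torch.Generator()` was evaluated when the module was imported, which by itself would have forced `torch` to load; the new code drops the default and constructs the generator inside the body instead. A small stand-in example of the difference (using `time.time()` in place of the torch call):

```python
import time
from typing import Optional


def eager(stamp: float = time.time()):  # default computed once, at definition time
    return stamp


def lazy(stamp: Optional[float] = None):  # default computed on each call
    if stamp is None:
        stamp = time.time()
    return stamp
```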
@@ -107,7 +112,7 @@


def get_next_fsm_states(
fsms: List["Guide"], fsm_states: List[int], next_token_ids: torch.Tensor
fsms: List["Guide"], fsm_states: List[int], next_token_ids: "torch.Tensor"
) -> List[int]:
"""

@@ -129,7 +134,7 @@ def get_next_fsm_states(
]


def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> torch.Tensor:
def get_allowed_tokens(fsms: List["Guide"], fsm_states: List[int]) -> "torch.Tensor":
"""Get the new instructions for each sequence from the finite-state machine.

Parameters
@@ -173,10 +178,9 @@ def is_generation_finished(fsms: List["Guide"], fsm_states: List[int]) -> bool:
return all([fsm.is_final_state(state) for fsm, state in zip(fsms, fsm_states)])


@torch.inference_mode()
def update_token_ids(
token_ids: torch.Tensor, next_token_ids: torch.Tensor, ancestors: torch.Tensor
) -> torch.Tensor:
token_ids: "torch.Tensor", next_token_ids: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
"""Append the sampled tokens to the running sequence of tokens.

Parameters
@@ -195,14 +199,15 @@ def update_token_ids(
just generated.

"""
import torch

token_ids = torch.index_select(token_ids, 0, ancestors)
return torch.concatenate([token_ids, next_token_ids], dim=-1)


@torch.inference_mode()
def update_attention_masks(
attention_masks: torch.Tensor, ancestors: torch.Tensor
) -> torch.Tensor:
attention_masks: "torch.Tensor", ancestors: "torch.Tensor"
) -> "torch.Tensor":
"""Expand the attention masks.

Parameters
@@ -217,6 +222,8 @@ def update_attention_masks(
The attention masks padded with 1s.

"""
import torch

attention_masks = torch.index_select(attention_masks, 0, ancestors)
return torch.concatenate(
[
@@ -229,15 +236,15 @@
)


def reorder_fsms(fsms: List["Guide"], ancestors: torch.Tensor) -> List["Guide"]:
def reorder_fsms(fsms: List["Guide"], ancestors: "torch.Tensor") -> List["Guide"]:
reordered_fsms = []
for ancestor in ancestors:
reordered_fsms.append(fsms[ancestor].copy())

return reordered_fsms


def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]:
def reorder_fsm_states(fsm_states: List[int], ancestors: "torch.Tensor") -> List[int]:
reordered_states = []
for ancestor in ancestors:
reordered_states.append(fsm_states[ancestor])
@@ -246,7 +253,7 @@ def reorder_fsm_states(fsm_states: List[int], ancestors: torch.Tensor) -> List[int]:


def reorder_kv_cache(
kv_cache: Optional[Tuple], ancestors: torch.Tensor
kv_cache: Optional[Tuple], ancestors: "torch.Tensor"
) -> Optional[Tuple]:
"""Re-order the KV-cache based on the ancestors.

@@ -256,6 +263,8 @@
first dimension is the batch size.

"""
import torch

if kv_cache is None:
return None

@@ -270,8 +279,7 @@
return new_kv_cache


@torch.inference_mode()
def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
def bias_logits(logits: "torch.Tensor", allowed_token_ids: List) -> "torch.Tensor":
"""Mask the logits.

The function iterates over a nested list where each list corresponds to the
@@ -290,6 +298,8 @@ def bias_logits(logits: torch.Tensor, allowed_token_ids: List) -> torch.Tensor:
A view of the original logits tensor where some values are masked.

"""
import torch

biased_logits = torch.full_like(logits, -math.inf, device=logits.device)
for i, ids in enumerate(allowed_token_ids):
biased_logits[i, ids] = logits[i, ids]
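To see what `bias_logits` computes: positions outside `allowed_token_ids` are set to `-inf`, so a subsequent softmax assigns them zero probability. A self-contained sketch mirroring the loop above:

```python
import math

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
allowed_token_ids = [[0, 2]]  # allowed token ids for each sequence in the batch

biased = torch.full_like(logits, -math.inf)
for i, ids in enumerate(allowed_token_ids):
    biased[i, ids] = logits[i, ids]

print(torch.softmax(biased, dim=-1))  # probability mass only on ids 0 and 2
```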
45 changes: 25 additions & 20 deletions outlines/models/transformers.py
@@ -1,16 +1,15 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from outlines.models.tokenizer import Tokenizer

if TYPE_CHECKING:
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

__all__ = ["transformers"]


KVCacheType = Tuple[Tuple[torch.DoubleTensor, torch.DoubleTensor], ...]
KVCacheType = Tuple[Tuple["torch.DoubleTensor", "torch.DoubleTensor"], ...]


def get_llama_tokenizer_types():
@@ -77,13 +76,13 @@ def __init__(self, tokenizer: "PreTrainedTokenizer", **kwargs):

def encode(
self, prompt: Union[str, List[str]], **kwargs
) -> Tuple[torch.LongTensor, torch.LongTensor]:
) -> Tuple["torch.LongTensor", "torch.LongTensor"]:
kwargs["padding"] = True
kwargs["return_tensors"] = "pt"
output = self.tokenizer(prompt, **kwargs)
return output["input_ids"], output["attention_mask"]

def decode(self, token_ids: torch.LongTensor) -> List[str]:
def decode(self, token_ids: "torch.LongTensor") -> List[str]:
text = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return text

@@ -127,13 +126,12 @@ def __init__(
self.model = model
self.tokenizer = TransformerTokenizer(tokenizer)

@torch.inference_mode
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
input_ids: "torch.LongTensor",
attention_mask: "torch.LongTensor",
past_key_values: Optional[Tuple] = None,
) -> Tuple[torch.FloatTensor, Optional[KVCacheType]]:
) -> Tuple["torch.FloatTensor", Optional[KVCacheType]]:
"""Compute a forward pass through the transformer model.

Parameters
@@ -151,28 +149,35 @@ def forward(
The computed logits and the new cached key and value tensors.

"""
try:
    import torch
except ImportError:
    raise ImportError(
        "The `torch` library needs to be installed to use `transformers` models."
    )
assert 0 < input_ids.ndim < 3

if past_key_values:
input_ids = input_ids[..., -1].unsqueeze(-1)

output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)
with torch.inference_mode():
output = self.model(
input_ids,
attention_mask=attention_mask,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
past_key_values=past_key_values,
)

return output.logits, output.past_key_values

def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
input_ids: "torch.LongTensor",
attention_mask: "torch.LongTensor",
past_key_values: Optional[Tuple] = None,
) -> torch.FloatTensor:
) -> "torch.FloatTensor":
logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
next_token_logits = logits[..., -1, :]

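The `try`/`except ImportError` guard in `forward` is a common optional-dependency idiom; chaining the re-raise with `from e` additionally preserves the original traceback. A generic sketch of the pattern (not this PR's exact code):

```python
def require_torch():
    """Import torch or fail with an actionable message (illustrative helper)."""
    try:
        import torch
    except ImportError as e:
        raise ImportError(
            "The `torch` library needs to be installed to use `transformers` models."
        ) from e
    return torch
```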