Make torch import optional
rlouf committed Apr 16, 2024
1 parent 895a82e commit f6e5523
Showing 5 changed files with 20 additions and 12 deletions.
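
The change follows a common deferred-import pattern: `torch` is imported only under `typing.TYPE_CHECKING`, the `torch.*` annotations are quoted so they are never evaluated at runtime, and the runtime import happens inside the functions that actually need it. Below is a minimal sketch of that pattern; the `ExampleModel` class and `load_model` function are illustrative names, not code from this repository.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; torch is never imported at runtime here.
    import torch


class ExampleModel:  # illustrative, not part of the outlines codebase
    def forward(self, input_ids: "torch.LongTensor") -> "torch.FloatTensor":
        """Quoted annotations keep this module importable without torch."""
        import torch  # deferred until the model is actually run

        return torch.zeros(input_ids.shape, dtype=torch.float)


def load_model():
    """Fail with an explicit message if the optional dependency is missing."""
    try:
        import torch  # noqa: F401
    except ImportError:
        raise ImportError(
            "The `torch` library needs to be installed in order to use this model."
        )
```

This keeps importing the package cheap for users who only need a backend that does not depend on `torch`, while type checkers still see the full annotations.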
4 changes: 2 additions & 2 deletions docs/installation.md
@@ -16,8 +16,8 @@ Outlines supports OpenAI, transformers, Mamba, llama.cpp and exllama2 but **you
 pip install openai
 pip install transformers datasets accelerate torch
 pip install llama-cpp-python
-pip install exllamav2 torch
-pip install mamba_ssm torch
+pip install exllamav2 transformers torch
+pip install mamba_ssm transformers torch
 pip install vllm
 ```

4 changes: 4 additions & 0 deletions docs/reference/models/exllamav2.md
@@ -1,3 +1,7 @@
 # ExllamaV2
 
+```bash
+pip install exllamav2 transformers torch
+```
+
 *Coming soon*
4 changes: 4 additions & 0 deletions docs/reference/models/mamba.md
@@ -1,3 +1,7 @@
 # Mamba
 
+```bash
+pip install mamba_ssm transformers torch
+```
+
 *Coming soon*
10 changes: 5 additions & 5 deletions outlines/models/exllamav2.py
@@ -1,11 +1,10 @@
 import os
 from typing import TYPE_CHECKING, Optional
 
-import torch
-
 if TYPE_CHECKING:
     from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora
     from transformers import PreTrainedTokenizer
+    import torch
 
 from .transformers import TransformerTokenizer

@@ -28,8 +27,9 @@ def __init__(
         self.past_seq = None
         self.lora = lora
 
-    def forward(self, input_ids: torch.LongTensor, *_):
+    def forward(self, input_ids: "torch.LongTensor", *_):
         """Compute a forward pass through the exl2 model."""
+        import torch
 
         # Caching with past_seq
         reset = True
@@ -74,7 +74,7 @@ def forward(self, input_ids: torch.LongTensor, *_):
                 seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora]
             )
 
-    def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
+    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
         logits = self.forward(input_ids)
         next_token_logits = logits[..., -1, :]

@@ -169,7 +169,7 @@ def exl2(
         from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
-            "The `exllamav2` library needs to be installed in order to use `exllamav2` models."
+            "The `exllamav2`, `transformers` and `torch` libraries need to be installed in order to use `exllamav2` models."
         )
 
     # Load tokenizer
10 changes: 5 additions & 5 deletions outlines/models/mamba.py
@@ -1,10 +1,9 @@
 from typing import TYPE_CHECKING, Optional
 
-import torch
-
 from .transformers import TransformerTokenizer
 
 if TYPE_CHECKING:
+    import torch
     from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
     from transformers import PreTrainedTokenizer

@@ -22,14 +21,14 @@ def __init__(
         self.model = model
         self.tokenizer = TransformerTokenizer(tokenizer)
 
-    def forward(self, input_ids: torch.LongTensor, *_):
+    def forward(self, input_ids: "torch.LongTensor", *_):
         """Compute a forward pass through the mamba model."""
 
         output = self.model(input_ids)
         next_token_logits = output.logits[..., -1, :]
         return next_token_logits, None
 
-    def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
+    def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
         return self.forward(input_ids)


@@ -40,11 +39,12 @@ def mamba(
     tokenizer_kwargs: dict = {},
 ):
     try:
+        import torch
         from mamba_ssm import MambaLMHeadModel
         from transformers import AutoTokenizer
     except ImportError:
         raise ImportError(
-            "The `mamba_ssm` library needs to be installed in order to use Mamba people."
+            "The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
         )
 
     if not torch.cuda.is_available():
