Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch optional import #821

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

<img src="./docs/assets/images/logo.png" alt="Outlines Logo" width=300></img>

[![Twitter][twitter-badge]][twitter]
[![.txt Twitter][dottxt-twitter-badge]][dottxt-twitter]
[![Outlines Twitter][outlines-twitter-badge]][outlines-twitter]

[![Contributors][contributors-badge]][contributors]
[![Downloads][downloads-badge]][pypistats]
Expand Down Expand Up @@ -355,9 +356,11 @@ answer = outlines.generate.text(model)(prompt, max_tokens=100)

[contributors]: https://github.com/outlines-dev/outlines/graphs/contributors
[contributors-badge]: https://img.shields.io/github/contributors/outlines-dev/outlines?style=flat-square&logo=github&logoColor=white&color=ECEFF4
[twitter]: https://twitter.com/dottxtai
[dottxt-twitter]: https://twitter.com/dottxtai
[outlines-twitter]: https://twitter.com/OutlinesOSS
[discord]: https://discord.gg/R9DSu34mGd
[discord-badge]: https://img.shields.io/discord/1182316225284554793?color=81A1C1&logo=discord&logoColor=white&style=flat-square
[downloads-badge]: https://img.shields.io/pypi/dm/outlines?color=89AC6B&logo=python&logoColor=white&style=flat-square
[pypistats]: https://pypistats.org/packages/outlines
[twitter-badge]: https://img.shields.io/twitter/follow/dottxtai?style=social
[dottxt-twitter-badge]: https://img.shields.io/twitter/follow/dottxtai?style=social
[outlines-twitter-badge]: https://img.shields.io/twitter/follow/OutlinesOSS?style=social
4 changes: 2 additions & 2 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ Outlines supports OpenAI, transformers, Mamba, llama.cpp and exllama2 but **you
pip install openai
pip install transformers datasets accelerate torch
pip install llama-cpp-python
pip install exllamav2 torch
pip install mamba_ssm torch
pip install exllamav2 transformers torch
pip install mamba_ssm transformers torch
pip install vllm
```

Expand Down
4 changes: 4 additions & 0 deletions docs/reference/models/exllamav2.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# ExllamaV2

```bash
pip install exllamav2 transformers torch
```

*Coming soon*
4 changes: 4 additions & 0 deletions docs/reference/models/mamba.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Mamba

```bash
pip install mamba_ssm transformers torch
```

*Coming soon*
10 changes: 5 additions & 5 deletions outlines/models/exllamav2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from typing import TYPE_CHECKING, Optional

import torch

if TYPE_CHECKING:
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora
from transformers import PreTrainedTokenizer
import torch

from .transformers import TransformerTokenizer

Expand All @@ -28,8 +27,9 @@ def __init__(
self.past_seq = None
self.lora = lora

def forward(self, input_ids: torch.LongTensor, *_):
def forward(self, input_ids: "torch.LongTensor", *_):
"""Compute a forward pass through the exl2 model."""
import torch

# Caching with past_seq
reset = True
Expand Down Expand Up @@ -74,7 +74,7 @@ def forward(self, input_ids: torch.LongTensor, *_):
seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora]
)

def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
logits = self.forward(input_ids)
next_token_logits = logits[..., -1, :]

Expand Down Expand Up @@ -169,7 +169,7 @@ def exl2(
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `exllamav2` library needs to be installed in order to use `exllamav2` models."
"The `exllamav2`, `transformers` and `torch` libraries need to be installed in order to use `exllamav2` models."
)

# Load tokenizer
Expand Down
10 changes: 5 additions & 5 deletions outlines/models/mamba.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Optional

import torch

from .transformers import TransformerTokenizer

if TYPE_CHECKING:
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import PreTrainedTokenizer

Expand All @@ -22,14 +21,14 @@ def __init__(
self.model = model
self.tokenizer = TransformerTokenizer(tokenizer)

def forward(self, input_ids: torch.LongTensor, *_):
def forward(self, input_ids: "torch.LongTensor", *_):
"""Compute a forward pass through the mamba model."""

output = self.model(input_ids)
next_token_logits = output.logits[..., -1, :]
return next_token_logits, None

def __call__(self, input_ids: torch.LongTensor, *_) -> torch.FloatTensor:
def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor":
return self.forward(input_ids)


Expand All @@ -40,11 +39,12 @@ def mamba(
tokenizer_kwargs: dict = {},
):
try:
import torch
from mamba_ssm import MambaLMHeadModel
from transformers import AutoTokenizer
except ImportError:
raise ImportError(
"The `mamba_ssm` library needs to be installed in order to use Mamba people."
"The `mamba_ssm`, `torch` and `transformers` libraries need to be installed in order to use Mamba models."
)

if not torch.cuda.is_available():
Expand Down
Loading