Skip to content

Commit

Permalink
Allow user to choose device for models
Browse files Browse the repository at this point in the history
  • Loading branch information
BramVanroy authored and brandonwillard committed Sep 6, 2023
1 parent 8b324ea commit 4b83ee1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
39 changes: 32 additions & 7 deletions outlines/models/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer


__all__ = ["transformers"]


Expand All @@ -19,10 +18,9 @@ def __init__(
self,
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
device: Optional[str] = None,
):
self.device = device if device is not None else "cpu"
self.model = model.to(self.device)
self.device = model.device
self.model = model
self.tokenizer = tokenizer

def __call__(
Expand Down Expand Up @@ -86,15 +84,42 @@ def convert_token_to_string(self, token: str) -> str:
return string


def transformers(
    model_name: str,
    device: Optional[str] = None,
    model_kwargs: Optional[dict] = None,
    tokenizer_kwargs: Optional[dict] = None,
):
    """Instantiate a model from the `transformers` library and its tokenizer.

    Parameters
    ----------
    model_name
        The name of the model as listed on Hugging Face's model page.
    device
        The device(s) on which the model should be loaded. When given, it is
        forwarded to `from_pretrained` as `device_map` and overrides any
        `device_map` entry in `model_kwargs`. When `None`, `model_kwargs` is
        passed through unchanged.
    model_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the model. The dictionary is
        copied, never modified in place.
    tokenizer_kwargs
        A dictionary that contains the keyword arguments to pass to the
        `from_pretrained` method when loading the tokenizer.

    Returns
    -------
    A `Transformers` model instance.

    Raises
    ------
    ImportError
        If the `transformers` library is not installed.

    """
    try:
        from transformers import AutoModelForCausalLM
    except ImportError as err:
        raise ImportError(
            "The `transformers` library needs to be installed in order to use `transformers` models."
        ) from err

    # Work on private copies: the previous `{}` defaults were shared between
    # calls, and writing `device_map` into the caller's dict leaked the device
    # choice into every later use of that dict.
    model_kwargs = {} if model_kwargs is None else dict(model_kwargs)
    tokenizer_kwargs = {} if tokenizer_kwargs is None else dict(tokenizer_kwargs)

    # Only override `device_map` when a device was actually requested, so an
    # explicit `device_map` in `model_kwargs` is not silently reset to `None`.
    if device is not None:
        model_kwargs["device_map"] = device

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tokenizer = TransformersTokenizer(model_name, **tokenizer_kwargs)

    return Transformers(model, tokenizer)
4 changes: 2 additions & 2 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def test_tokenizer():


def test_model():
with pytest.raises(RuntimeError, match="Expected one of cpu, cuda"):
with pytest.raises(ValueError, match="When passing device_map as a string"):
transformers(TEST_MODEL, device="non_existent")

model = transformers(TEST_MODEL, device="cpu")
assert isinstance(model.tokenizer, TransformersTokenizer)
assert model.device == "cpu"
assert model.device.type == "cpu"

input_ids = torch.tensor([[0, 1, 2]])
logits = model(input_ids, torch.ones_like(input_ids))
Expand Down

0 comments on commit 4b83ee1

Please sign in to comment.