Vectorize HF model calls
rlouf committed Jun 6, 2023
1 parent b858751 commit e3cdf0e
Showing 4 changed files with 61 additions and 29 deletions.
outlines/models/hf_diffusers.py (28 changes: 23 additions & 5 deletions)
@@ -1,6 +1,12 @@
 """Integration with HuggingFace's `diffusers` library."""
+import functools
+from typing import List, Union
+
+import numpy as np
 from PIL.Image import Image as PILImage
 
+import outlines
+
 
 def HuggingFaceDiffuser(model_name: str) -> PILImage:
     """Create a function that will call a stable diffusion pipeline.
@@ -12,17 +18,20 @@ def HuggingFaceDiffuser(model_name: str) -> PILImage:
     """
 
-    def call(prompt: str, samples: int = 1) -> str:
+    def call(prompt: Union[str, List[str]], samples: int = 1) -> str:
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
         results = call_stable_diffusion_pipeline(model_name, prompt, samples)
         if samples == 1:
             return results[0]
 
         return results
 
     return call
 
 
+@functools.partial(outlines.vectorize, signature="(),(m),()->(m,s)")
 def call_stable_diffusion_pipeline(
-    model_name: str, prompt: str, samples: int
+    model_name: str, prompt: List[str], samples: int
 ) -> PILImage:
     """Build and call the Stable Diffusion pipeline.
@@ -31,10 +40,19 @@ def call_stable_diffusion_pipeline(
     import torch
     from diffusers import StableDiffusionPipeline
 
+    # Pipelines don't accept NumPy arrays
+    prompt = list(prompt)
+
     pipe = StableDiffusionPipeline.from_pretrained(model_name)
     if torch.cuda.is_available():
         pipe = pipe.to("cuda")
 
     images = pipe(prompt, num_images_per_prompt=samples).images
+    if not isinstance(images, list):
+        images = [images]
+
+    array = np.empty((samples,), dtype="object")
+    for idx, image in enumerate(images):
+        array[idx] = image
 
-    return images
+    return np.atleast_2d(array)
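
A note on the new decorator: "(),(m),()->(m,s)" is a NumPy generalized-ufunc signature saying the pipeline takes a scalar model name, a vector of m prompts, and a scalar sample count, and returns an m x s grid of images. Below is a minimal sketch of those broadcasting semantics, using numpy.vectorize as a stand-in; the dummy pipeline is invented for illustration, and it is an assumption that outlines.vectorize follows NumPy's signature conventions.

import numpy as np


def dummy_pipeline(model_name, prompts, samples):
    # `prompts` arrives as a 1-D array of m prompts; return an (m, samples)
    # object array holding one row of "images" per prompt.
    out = np.empty((len(prompts), samples), dtype="object")
    for i, prompt in enumerate(prompts):
        for s in range(samples):
            out[i, s] = f"<image for {prompt!r}, sample {s}>"
    return out


vectorized = np.vectorize(dummy_pipeline, signature="(),(m),()->(m,s)")
grid = vectorized("some-model", np.array(["a cat", "a dog"]), 2)
print(grid.shape)  # (2, 2): one row per prompt, one column per sample
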
outlines/models/hf_transformers.py (48 changes: 32 additions & 16 deletions)
@@ -1,7 +1,10 @@
 """Integration with HuggingFace's `transformers` library."""
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+import functools
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
-from outlines.caching import cache
+import numpy as np
+
+import outlines
 
 if TYPE_CHECKING:
     import torch
@@ -47,34 +50,47 @@ def HuggingFaceCompletion(
         temperature = 1.0
 
     def call(
-        prompt: str,
+        prompt: Union[str, List[str]],
         *,
         samples: int = 1,
-        stop_at: Optional[List[str]] = None,
-        is_in: Optional[List[str]] = None,
+        stop_at: List[Optional[str]] = [],
+        is_in: List[Optional[str]] = [],
         type: Optional[str] = None,
     ) -> str:
+        if isinstance(prompt, str):
+            prompt = [prompt]
+
         return call_model_generate_method(
-            model_name, prompt, max_tokens, temperature, samples, stop_at, is_in, type
+            model_name,
+            prompt,
+            max_tokens,
+            temperature,
+            samples,
+            stop_at,
+            is_in,
+            type,
         )
 
     return call
 
 
-@cache
+@functools.partial(outlines.vectorize, signature="(),(m),(),(),(),(i),(j),()->(m,s)")
 def call_model_generate_method(
     model_name: str,
     prompt: str,
     max_tokens: int,
     temperature: float,
     samples: int,
-    stop_at: List[str],
-    is_in: List[str],
+    stop_at: List[Optional[str]],
+    is_in: np.ndarray,
     type: str,
 ) -> str:
     import torch
     from transformers import AutoModelForCausalLM, AutoTokenizer
 
+    # `generate` does not accept NumPy arrays
+    prompt = list(prompt)
+
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -88,7 +104,7 @@ def call_model_generate_method(
             raise NotImplementedError(
                 "It is currently not possible to control the generation of several samples with the `transformers` integration"
             )
-        if is_in is not None:
+        if is_in.size > 0:
             raise ValueError(
                 "You cannot both restrict to a set of choices with `is_in` and to a type with `type`"
             )
@@ -97,12 +113,12 @@ def call_model_generate_method(
         )
         logit_processors = [logit_processor]
         stopping_criteria = [stopping_criterion]
-    elif is_in is not None:
+    elif is_in.size > 0:
         if samples > 1:
             raise NotImplementedError(
                 "It is currently not possible to control the generation of several samples with the `transformers` integration"
             )
-        if stop_at is not None:
+        if stop_at.size > 0:
             raise ValueError(
                 "You cannot both restrict to a set of choices with `is_in` and set a stopping criterion"
             )
@@ -111,7 +127,7 @@ def call_model_generate_method(
         )
         logit_processors = [logit_processor]
         stopping_criteria = [stopping_criterion]
-    elif stop_at is not None:
+    elif stop_at.size > 0:
         if samples > 1:
             raise NotImplementedError(
                 "It is currently not possible to control the generation of several samples with the `transformers` integration"
@@ -132,7 +148,7 @@ def call_model_generate_method(
             temperature=temperature,
             max_new_tokens=max_tokens,
             pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=samples,
+            num_return_sequences=int(samples),
             logits_processor=logit_processors,
             stopping_criteria=stopping_criteria,
         )
@@ -141,11 +157,11 @@ def call_model_generate_method(
 
     if samples == 1:
         results = tokenizer.decode(new_tokens, skip_special_tokens=True)
-        results = postprocessing(results)
+        results = [postprocessing(results)]
     else:
         results = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
 
-    return results
+    return np.atleast_2d(results)
 
 
 def create_stop_constraint(
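
For readers new to generalized-ufunc signatures, here is how the core dimensions of the new decorator map onto call_model_generate_method's parameters; this table is a reading aid added for this writeup, not part of the commit.

# signature="(),(m),(),(),(),(i),(j),()->(m,s)"
#
#   ()    model_name    scalar
#   (m)   prompt        vector of m prompts
#   ()    max_tokens    scalar
#   ()    temperature   scalar
#   ()    samples       scalar
#   (i)   stop_at       vector of i stop sequences
#   (j)   is_in         vector of j allowed choices
#   ()    type          scalar
#   (m,s) return value  one row of s samples per prompt
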
tests/models/test_hf_diffusers.py (3 changes: 2 additions & 1 deletion)
@@ -1,3 +1,4 @@
+import numpy as np
 from PIL.Image import Image as PILImage
 
 from outlines.models.hf_diffusers import HuggingFaceDiffuser
@@ -12,7 +13,7 @@ def test_stable_diffusion():
     assert isinstance(image, PILImage)
 
     images = model("test", samples=3)
-    assert isinstance(images, list)
+    assert isinstance(images, np.ndarray)
     assert len(images) == 3
     for img in images:
         assert isinstance(image, PILImage)
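
Assembled from the test above, a hedged usage sketch for the diffusers side; the checkpoint name is illustrative, and running it for real downloads a Stable Diffusion model.

import numpy as np
from PIL.Image import Image as PILImage

from outlines.models.hf_diffusers import HuggingFaceDiffuser

# Illustrative checkpoint; any diffusers-compatible model id should work.
model = HuggingFaceDiffuser("hf-internal-testing/tiny-stable-diffusion-torch")

image = model("a watercolor of a lighthouse")  # one prompt -> one PIL image
images = model("a watercolor of a lighthouse", samples=3)
assert isinstance(images, np.ndarray)  # object array after this commit
assert all(isinstance(img, PILImage) for img in images)
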
tests/models/test_hf_transformers.py (11 changes: 4 additions & 7 deletions)
@@ -1,10 +1,7 @@
-import outlines
+import numpy as np
+import pytest
 
-outlines.disable_cache()
-
-import pytest  # noqa
-
-from outlines.models.hf_transformers import HuggingFaceCompletion  # noqa
+from outlines.models.hf_transformers import HuggingFaceCompletion
 
 TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
@@ -19,7 +16,7 @@ def test_samples():
     assert isinstance(answer, str)
 
     answers = model("test", samples=3)
-    assert isinstance(answers, list)
+    assert isinstance(answers, np.ndarray)
     assert len(answers) == 3
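
Putting it together, a usage sketch mirroring this test; the prompt is illustrative, and it assumes the factory's remaining parameters (e.g. max_tokens) have workable defaults.

import numpy as np

from outlines.models.hf_transformers import HuggingFaceCompletion

model = HuggingFaceCompletion("hf-internal-testing/tiny-random-GPTJForCausalLM")

answer = model("Write one word.")  # one prompt, one sample -> plain str
answers = model("Write one word.", samples=3)  # -> NumPy array of 3 strings
assert isinstance(answers, np.ndarray)
assert len(answers) == 3
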

