Refactor the HuggingFace diffuser connector
rlouf committed Apr 13, 2023
1 parent b856206 commit 50dc834
Showing 5 changed files with 56 additions and 32 deletions.
22 changes: 10 additions & 12 deletions outlines/image.py
@@ -1,22 +1,20 @@
-from outlines.text import prompt
+from typing import Any, Callable, Dict, List
+
+from PIL.Image import Image as PILImage
+
+import outlines.models.routers as routers
+from outlines.text import prompt


-def generation(name: str):
-    """Decorator that allows to simplify calls to image generation models."""
-    provider_name = name.split("/")[0]
-    model_name = name[len(provider_name) + 1 :]
-
-    if provider_name == "hf":
-        from outlines.image.models.hugging_face import HFDiffuser
-
-        generative_model = HFDiffuser(model_name)  # type:ignore
-    else:
-        raise NameError(f"The model provider {provider_name} is not available.")
+def generation(model_path: str) -> Callable:
+    """Decorator that allows to simplify calls to image generation models."""
+    generative_model_builder = routers.image_generation(model_path)
+    generative_model = generative_model_builder()

-    def decorator(fn):
+    def decorator(fn: Callable):
         prompt_fn = prompt(fn)

-        def wrapper(*args, **kwargs):
+        def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> PILImage:
             """Call the Diffuser with the rendered template.

             Returns
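For reference, the refactored decorator is presumably used the same way as before: a prompt-template function is wrapped, its rendered output is sent to the image model, and a PIL image comes back. A minimal sketch under that assumption; the model identifier and template are illustrative, not part of this commit:

    from outlines.image import generation

    @generation("hf/runwayml/stable-diffusion-v1-5")  # the "hf" prefix routes to the HuggingFace diffusers backend
    def illustration(subject):
        "A watercolor illustration of {{subject}}"  # rendered by outlines.text.prompt before being passed to the pipeline

    picture = illustration("a lighthouse at dusk")  # returns a PIL.Image.Image per the wrapper's annotation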
1 change: 1 addition & 0 deletions outlines/models/__init__.py
@@ -1,2 +1,3 @@
+from .hf_diffusers import HuggingFaceDiffuser
 from .hf_transformers import HuggingFaceCompletion
 from .openai import OpenAICompletion
30 changes: 16 additions & 14 deletions outlines/models/hf_diffusers.py
@@ -1,21 +1,23 @@
-try:
-    from diffusers import StableDiffusionPipeline
-except ImportError:
-    raise ImportError(
-        "You need to install `torch` and `diffusers` to run the StableDiffusion model."
-    )
+from PIL.Image import Image as PILImage


-class HFDiffuser:
-    """A `StableDiffusion` distributed random image."""
+def HuggingFaceDiffuser(model_name: str) -> PILImage:
+    """Create a function that will call a stable diffusion pipeline.

-    def __init__(self, model_name: str):
-        self.model_name = model_name
+    Parameters
+    ----------
+    model_name: str
+        The name of the model as listed on HuggingFace's models page.

-    def __call__(self, prompt: str) -> str:
-        """Use HuggingFace's `StableDiffusion` pipeline to sample a new image."""
-        pipe = StableDiffusionPipeline.from_pretrained(self.model_name)
-        pipe = pipe.to("cuda")
+    """
+
+    def call(prompt: str) -> str:
+        import torch
+        from diffusers import StableDiffusionPipeline
+
+        pipe = StableDiffusionPipeline.from_pretrained(model_name)
+        if torch.cuda.is_available():
+            pipe = pipe.to("cuda")
         image = pipe(prompt).images[0]

         return image
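The connector is now a plain factory: the router below binds the model name with functools.partial, the decorator in outlines/image.py calls the resulting builder once, and torch/diffusers are only imported when an image is actually requested. A minimal sketch of that flow, assuming the factory ends by returning the inner call function (that return statement is not visible in the hunk above) and using an illustrative model name:

    import outlines.models.routers as routers

    builder = routers.image_generation("hf/runwayml/stable-diffusion-v1-5")
    # builder is functools.partial(HuggingFaceDiffuser, "runwayml/stable-diffusion-v1-5")

    generate = builder()  # builds the closure; nothing heavy is imported yet
    img = generate("an astronaut riding a horse")  # torch and diffusers are imported here, then the pipeline runs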
34 changes: 28 additions & 6 deletions outlines/models/routers.py
@@ -8,12 +8,6 @@

 def language_completion(model_path: str) -> Callable:
     """Return the model and model name corresponding to the model path.
-
-    Note
-    ----
-    We return both the model builder and the model name instead of partially
-    applying the model name to the model builder
-
     Parameters
     ----------
     model_path
@@ -40,6 +34,34 @@ def language_completion(model_path: str) -> Callable:
     return functools.partial(model, model_name)


+def image_generation(model_path: str) -> Callable:
+    """Return the model and model name corresponding to the model path.
+
+    Parameters
+    ----------
+    model_path
+        A string of the form "model_provider/model_name"
+
+    Returns
+    -------
+    The model builder with bound model name.
+
+    """
+
+    registry: Dict[str, Callable] = {
+        "hf": models.HuggingFaceDiffuser,
+    }
+
+    provider, model_name = parse_model_path(model_path)
+
+    try:
+        model = registry[provider]
+    except KeyError:
+        raise ValueError(f"The model provider {provider} is not available.")
+
+    return functools.partial(model, model_name)
+
+
 def parse_model_path(model_path: str) -> Tuple[str, str]:
     """Parse a model path in the form 'provider/model_name'"""

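To make the routing concrete, here is how the registry above behaves for a known and an unknown provider; the model names are illustrative, and this assumes parse_model_path splits the provider off at the first slash, as the old outlines/image.py code did:

    routers.image_generation("hf/stabilityai/stable-diffusion-2-1")
    # -> functools.partial(models.HuggingFaceDiffuser, "stabilityai/stable-diffusion-2-1")

    routers.image_generation("openai/dall-e-2")
    # -> raises ValueError: The model provider openai is not available.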
1 change: 1 addition & 0 deletions pyproject.toml
@@ -51,6 +51,7 @@ module = [
     "diffusers",
     "jinja2",
     "openai",
+    "PIL.Image",
     "pytest",
     "torch",
     "transformers",
