-
Notifications
You must be signed in to change notification settings - Fork 414
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor the HuggingFace
diffuser
connector
- Loading branch information
Showing
5 changed files
with
56 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .hf_diffusers import HuggingFaceDiffuser | ||
from .hf_transformers import HuggingFaceCompletion | ||
from .openai import OpenAICompletion |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,23 @@ | ||
try: | ||
from diffusers import StableDiffusionPipeline | ||
except ImportError: | ||
raise ImportError( | ||
"You need to install `torch` and `diffusers` to run the StableDiffusion model." | ||
) | ||
from PIL.Image import Image as PILImage | ||
|
||
|
||
class HFDiffuser: | ||
"""A `StableDiffusion` distributed random image.""" | ||
def HuggingFaceDiffuser(model_name: str) -> PILImage: | ||
"""Create a function that will call a stable diffusion pipeline. | ||
def __init__(self, model_name: str): | ||
self.model_name = model_name | ||
Parameters | ||
---------- | ||
model_name: str | ||
The name of the model as listed on HuggingFace's models page. | ||
def __call__(self, prompt: str) -> str: | ||
"""Use HuggingFace's `StableDiffusion` pipeline to sample a new image.""" | ||
pipe = StableDiffusionPipeline.from_pretrained(self.model_name) | ||
pipe = pipe.to("cuda") | ||
""" | ||
|
||
def call(prompt: str) -> str: | ||
import torch | ||
from diffusers import StableDiffusionPipeline | ||
|
||
pipe = StableDiffusionPipeline.from_pretrained(model_name) | ||
if torch.cuda.is_available(): | ||
pipe = pipe.to("cuda") | ||
image = pipe(prompt).images[0] | ||
|
||
return image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,7 @@ module = [ | |
"diffusers", | ||
"jinja2", | ||
"openai", | ||
"PIL.Image", | ||
"pytest", | ||
"torch", | ||
"transformers", | ||
|