Skip to content

Commit

Permalink
Support K-LMS in img2img (huggingface#270)
Browse files Browse the repository at this point in the history
* Support K-LMS in img2img

* Apply review suggestions
  • Loading branch information
anton-l committed Aug 29, 2022
1 parent da7d4cf commit efa773a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
34 changes: 28 additions & 6 deletions examples/inference/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import torch

import PIL
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
Expand Down Expand Up @@ -87,12 +94,17 @@ def __call__(
# get the original timestep using init_timestep
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
if isinstance(self.scheduler, LMSDiscreteScheduler):
timesteps = torch.tensor(
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
)
else:
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)

# get prompt text embeddings
text_input = self.tokenizer(
Expand Down Expand Up @@ -133,8 +145,15 @@ def __call__(
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
latent_model_input = latent_model_input.to(self.unet.dtype)
t = t.to(self.unet.dtype)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
Expand All @@ -145,11 +164,14 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents.to(self.vae.dtype))

image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __call__(
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# predict the noise residual
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ def step(
return {"prev_sample": prev_sample}

def add_noise(self, original_samples, noise, timesteps):
sigmas = self.match_shape(self.sigmas, noise)
noisy_samples = original_samples + noise * sigmas[timesteps]
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas

return noisy_samples

def __len__(self):
Expand Down

0 comments on commit efa773a

Please sign in to comment.