Skip to content

Commit

Permalink
actually use classifier free guidance during sampling, using the cond…
Browse files Browse the repository at this point in the history
…_scale > 1
  • Loading branch information
lucidrains committed Apr 13, 2022
1 parent 4303661 commit 174a896
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.3',
version = '0.0.4',
license='MIT',
description = 'Video Diffusion - Pytorch',
author = 'Phil Wang',
Expand Down
17 changes: 8 additions & 9 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def __init__(

def forward_with_cond_scale(self, *args, cond_scale = 2., **kwargs):
logits = self.forward(*args, null_cond_prob = 0., **kwargs)

if cond_scale == 1:
return logits

Expand Down Expand Up @@ -434,8 +433,8 @@ def q_posterior(self, x_start, x_t, t):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, clip_denoised: bool, cond = None):
x_recon = self.predict_start_from_noise(x, t=t, noise = self.denoise_fn(x, t, cond = cond))
def p_mean_variance(self, x, t, clip_denoised: bool, cond = None, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t=t, noise = self.denoise_fn.forward_with_cond_scale(x, t, cond = cond, cond_scale = cond_scale))

if clip_denoised:
x_recon.clamp_(-1., 1.)
Expand All @@ -444,32 +443,32 @@ def p_mean_variance(self, x, t, clip_denoised: bool, cond = None):
return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_sample(self, x, t, cond = None, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, cond = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised, cond = cond)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised, cond = cond, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

@torch.no_grad()
def p_sample_loop(self, shape, cond = None):
def p_sample_loop(self, shape, cond = None, cond_scale = 1.):
device = self.betas.device

b = shape[0]
img = torch.randn(shape, device=device)

for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), cond = cond)
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), cond = cond, cond_scale = cond_scale)
return img

@torch.no_grad()
def sample(self, cond = None, batch_size = 16):
def sample(self, cond = None, cond_scale = 1., batch_size = 16):
batch_size = cond.shape[0] if exists(cond) else batch_size
image_size = self.image_size
channels = self.channels
num_frames = self.num_frames
return self.p_sample_loop((batch_size, channels, num_frames, image_size, image_size), cond = cond)
return self.p_sample_loop((batch_size, channels, num_frames, image_size, image_size), cond = cond, cond_scale = cond_scale)

@torch.no_grad()
def interpolate(self, x1, x2, t = None, lam = 0.5):
Expand Down

0 comments on commit 174a896

Please sign in to comment.