Skip to content

Commit

Permalink
Only apply the gradient penalty every 4th step, like StyleGAN
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Apr 28, 2022
1 parent 177b27e commit 707c5fc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
13 changes: 11 additions & 2 deletions nuwa_pytorch/train_vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def __init__(
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10
ema_update_every = 10,
apply_grad_penalty_every = 4,
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
Expand Down Expand Up @@ -247,6 +248,8 @@ def __init__(
self.save_model_every = save_model_every
self.save_results_every = save_results_every

self.apply_grad_penalty_every = apply_grad_penalty_every

self.results_folder = Path(results_folder)

if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
Expand All @@ -257,6 +260,7 @@ def __init__(
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

self.vae.train()

Expand All @@ -270,7 +274,12 @@ def train_step(self):
img = next(self.dl)
img = img.to(device)

loss = self.vae(img, return_loss = True)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)

accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

(loss / self.grad_accum_every).backward()
Expand Down
11 changes: 6 additions & 5 deletions nuwa_pytorch/vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ def forward(
img,
return_loss = False,
return_discr_loss = False,
return_recons = False
return_recons = False,
apply_grad_penalty = False
):
batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
Expand All @@ -485,11 +486,11 @@ def forward(

fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))

gp = gradient_penalty(img, img_discr_logits)
loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)

loss = discr_loss + gp
if apply_grad_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = loss + gp

if return_recons:
return loss, fmap
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'nuwa-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.7.4',
version = '0.7.5',
license='MIT',
description = 'NÜWA - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 707c5fc

Please sign in to comment.