diff --git a/nuwa_pytorch/train_vqgan_vae.py b/nuwa_pytorch/train_vqgan_vae.py
index d1dd587..de71468 100644
--- a/nuwa_pytorch/train_vqgan_vae.py
+++ b/nuwa_pytorch/train_vqgan_vae.py
@@ -185,7 +185,8 @@ def __init__(
         random_split_seed = 42,
         ema_beta = 0.995,
         ema_update_after_step = 2000,
-        ema_update_every = 10
+        ema_update_every = 10,
+        apply_grad_penalty_every = 4,
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -247,6 +248,8 @@ def __init__(
         self.save_model_every = save_model_every
         self.save_results_every = save_results_every
 
+        self.apply_grad_penalty_every = apply_grad_penalty_every
+
         self.results_folder = Path(results_folder)
 
         if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
@@ -257,6 +260,7 @@ def __init__(
     def train_step(self):
         device = next(self.vae.parameters()).device
         steps = int(self.steps.item())
+        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
 
         self.vae.train()
 
@@ -270,7 +274,12 @@ def train_step(self):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(img, return_loss = True)
+            loss = self.vae(
+                img,
+                return_loss = True,
+                apply_grad_penalty = apply_grad_penalty
+            )
+
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
             (loss / self.grad_accum_every).backward()
diff --git a/nuwa_pytorch/vqgan_vae.py b/nuwa_pytorch/vqgan_vae.py
index 07f2881..831fd9b 100644
--- a/nuwa_pytorch/vqgan_vae.py
+++ b/nuwa_pytorch/vqgan_vae.py
@@ -460,7 +460,8 @@ def forward(
         img,
         return_loss = False,
         return_discr_loss = False,
-        return_recons = False
+        return_recons = False,
+        apply_grad_penalty = False
     ):
         batch, channels, height, width, device = *img.shape, img.device
         assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -485,11 +486,11 @@ def forward(
 
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
 
-            gp = gradient_penalty(img, img_discr_logits)
+            loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
 
-            discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
-
-            loss = discr_loss + gp
+            if apply_grad_penalty:
+                gp = gradient_penalty(img, img_discr_logits)
+                loss = loss + gp
 
             if return_recons:
                 return loss, fmap
diff --git a/setup.py b/setup.py
index 0dde3a8..c7eb78c 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.4',
+  version = '0.7.5',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',