Skip to content

Commit

Permalink
fix potential bug with discriminator training in vqgan-vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 17, 2023
1 parent 1f09e65 commit a3e3a6d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion nuwa_pytorch/train_vqgan_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def train_step(self):
# update discriminator

if exists(self.vae.discr):
self.discr_optim.zero_grad()
discr_loss = 0

for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
Expand All @@ -302,7 +304,6 @@ def train_step(self):
(loss / self.grad_accum_every).backward()

self.discr_optim.step()
self.discr_optim.zero_grad()

# log

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'nuwa-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.7.7',
version = '0.7.8',
license='MIT',
description = 'NÜWA - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit a3e3a6d

Please sign in to comment.