Softmax GAN example does not produce good looking digits #6134
Comments
Hi! Thanks for reporting this issue. I know what the problem is and will try to fix it as soon as possible.
@carmocca in the example, where should I call `zero_grad()`?
Before you compute the loss. I've updated your colab link. However, we will be rolling out a solution in #6147.
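As background for why the `zero_grad()` placement matters: PyTorch accumulates gradients across `backward()` calls until they are explicitly cleared, so calling `zero_grad()` before computing the loss for the current step ensures the optimizer only sees that step's gradients. A minimal standalone sketch (plain PyTorch, not taken from the example):

```python
import torch

# A single parameter; gradients accumulate across backward() calls
# unless they are explicitly zeroed.
w = torch.ones(1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

loss = (w * 3).sum()
loss.backward()
print(w.grad.item())  # 3.0

# Without zero_grad(), the next backward() ADDS to the existing grad.
loss = (w * 3).sum()
loss.backward()
print(w.grad.item())  # 6.0

# Clearing first gives only the fresh gradient for this step.
opt.zero_grad()
loss = (w * 3).sum()
loss.backward()
print(w.grad.item())  # 3.0
```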
I can't find it.
This is the updated `training_step`:

```python
def training_step(self, batch, batch_idx, optimizer_idx):
    imgs, _ = batch

    # sample noise
    z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
    z = z.type_as(imgs)

    # CHANGE
    g_opt, d_opt = self.optimizers()

    # train generator
    if optimizer_idx == 0:
        # generate images
        self.generated_imgs = self(z)

        # log sampled images
        sample_imgs = self.generated_imgs[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('generated_images', grid, 0)

        # ground truth result (ie: all fake)
        # put on GPU because we created this tensor inside training_loop
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)

        # CHANGE
        g_opt.zero_grad()

        # adversarial loss is binary cross-entropy
        g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
        tqdm_dict = {'g_loss': g_loss}
        output = OrderedDict({
            'loss': g_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output

    # train discriminator
    if optimizer_idx == 1:
        # Measure discriminator's ability to classify real from generated samples

        # how well can it label as real?
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs)
        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        # how well can it label as fake?
        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs)
        fake_loss = self.adversarial_loss(
            self.discriminator(self(z).detach()), fake)

        # CHANGE
        d_opt.zero_grad()

        # discriminator loss is the average of these
        d_loss = (real_loss + fake_loss) / 2
        tqdm_dict = {'d_loss': d_loss}
        output = OrderedDict({
            'loss': d_loss,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })
        return output
```
Awesome, thanks!
Btw, do you know how many epochs are really needed for good-looking digits?
Not really, sorry! Feel free to open a PR adding a comment about it if you find out 😉
@mtrencseni I ran the original example code (not your code on Google Colab), but, for an unknown reason, I didn't get a similar result to yours... The generated image below is at epoch 15 with batch_size=64. |
@mtrencseni Could you try the example (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py) again? If the problem still persists, please feel free to reopen this issue. |
🐛 Bug
I think the softmax GAN example is buggy: it doesn't produce good-looking digits even after 100-200 epochs.
This: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py
To Reproduce
https://colab.research.google.com/drive/1T6TpBvtFt14UrvCwDgP3eIAzaz-Af_e9#scrollTo=8Dq7kWkVF31y
Expected behavior
Good looking digits are produced.