
Softmax GAN example does not produce good looking digits #6134

Closed
mtrencseni opened this issue Feb 22, 2021 · 11 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task), waiting on author (Waiting on user action, correction, or update)

Comments

@mtrencseni commented Feb 22, 2021

🐛 Bug

I think the softmax GAN example is buggy: it doesn't produce good-looking digits even after 100-200 epochs.

This: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py

To Reproduce

https://colab.research.google.com/drive/1T6TpBvtFt14UrvCwDgP3eIAzaz-Af_e9#scrollTo=8Dq7kWkVF31y

Expected behavior

Good looking digits are produced.

mtrencseni added the bug (Something isn't working) and help wanted (Open to be worked on) labels on Feb 22, 2021
@carmocca (Contributor)

Hi! Thanks for reporting this issue.

I know what the problem is. I'll try to fix it as soon as possible.

@mtrencseni (Author)

@carmocca in the example, where should I call optimizer.zero_grad() to make it work?

@carmocca (Contributor)

Before you compute the loss. I've updated your Colab link.

However, we will be rolling out a solution in #6147
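
For anyone landing here later, here is a minimal sketch of the ordering being described, written with Lightning's manual optimization rather than the automatic mode used in the example. The `generator_loss`/`discriminator_loss` helpers are hypothetical placeholders, not part of the template; this is not the exact fix that #6147 ships.

    # Sketch only: zero_grad() goes before the loss computation.
    # Assumes a LightningModule with two optimizers (generator first,
    # discriminator second) and self.automatic_optimization = False.
    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()

        # generator update: clear stale gradients, then compute the loss
        g_opt.zero_grad()
        g_loss = self.generator_loss(batch)  # hypothetical helper
        self.manual_backward(g_loss)
        g_opt.step()

        # discriminator update follows the same zero_grad -> loss -> step order
        d_opt.zero_grad()
        d_loss = self.discriminator_loss(batch)  # hypothetical helper
        self.manual_backward(d_loss)
        d_opt.step()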

@mtrencseni (Author) commented Feb 24, 2021

I can't find zero_grad in there. Should it be in training_step()?
Btw, I don't think you can edit my Colab file, can you?
Sorry, I'm not an expert at PL (yet); I don't know how to retrieve the optimizer object in these callbacks.

@carmocca (Contributor)

This is the updated training_step:

    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch
        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        # CHANGE: grab both optimizers (returned in the order defined in configure_optimizers)
        g_opt, d_opt = self.optimizers()

        # train generator
        if optimizer_idx == 0:
            # generate images
            self.generated_imgs = self(z)
            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, 0)
            # ground truth result: label everything as real, since the generator wants its fakes classified as real
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # CHANGE: clear stale generator gradients before computing the loss
            g_opt.zero_grad()

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples
            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)
            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # CHANGE: clear stale discriminator gradients before computing the loss
            d_opt.zero_grad()

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output
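
For completeness, `self.optimizers()` returns the optimizers in the order they are returned from `configure_optimizers`, so the `g_opt, d_opt` unpacking above assumes something like the sketch below (modeled on the GAN template; the `lr`, `b1`, `b2` hyperparameter names are assumptions carried over from it, and the usual `torch` / `torch.nn.functional as F` imports are assumed):

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1, b2 = self.hparams.b1, self.hparams.b2
        # generator optimizer first, discriminator second; this order is what
        # makes `g_opt, d_opt = self.optimizers()` unpack correctly above
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def adversarial_loss(self, y_hat, y):
        # plain binary cross-entropy between predictions and the 0/1 targets
        return F.binary_cross_entropy(y_hat, y)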

@mtrencseni (Author)

Awesome, thanks!

@mtrencseni (Author)

Btw, do you know how many epochs are really needed for good-looking digits?

@carmocca (Contributor)

Not really, sorry! Feel free to open a PR adding a comment about it if you find out 😉

@mtrencseni (Author)

Doesn't work; this is what I get after 100 epochs:

[image: softmax_gan (generated samples after 100 epochs)]

@akihironitta (Contributor) commented Feb 25, 2021

@mtrencseni I ran the original example code (not your code on Google Colab), but for some unknown reason I couldn't reproduce your result...

The generated image below is at epoch 15 with batch_size=64.
[image: 20210225-182248 (generated digits at epoch 15)]

edenlightning added the waiting on author (Waiting on user action, correction, or update) label on Mar 1, 2021
@akihironitta (Contributor)

@mtrencseni Could you try the example (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/generative_adversarial_net.py) again? If the problem persists, please feel free to reopen this issue.
