
acgan: Add batch normalization to the Generator, etc #8616

Merged: 2 commits into keras-team:master from the patch-8 branch, Jan 19, 2018

Conversation

ozabluda
Contributor

  1. Add batch normalization to the Generator (see the sketch after this list). This makes the example closer to the referenced paper and improves the generated images. Adding batch normalization to the Discriminator breaks training so badly that I suspect a bug (maybe #5647, "Add tests exposing BatchNormalization bug(s) when used in GANs", was fixed incompletely or something). Not adding batch normalization to the Discriminator also side-steps the issue of correlation of samples within a batch (https://github.com/soumith/ganhacks#4-batchnorm)

  2. Use one-sided soft labels with a harder soft_one (0.95 vs 0.9). The referenced paper says they don't need one-sided soft labels. This example also doesn't "need" them any more, but the generated images are better with them. Added a reference to the paper.

  3. Increase epochs to 100 from 50, as good images often appear between epochs 50 and 100. Note that the training time per epoch is half that of the original example after 67cd3b0

  4. Increase the output precision of the various losses from 2 decimal digits to 4. You can't really tell what is going on with just 2.
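
A minimal sketch of the change in item 1 (illustrative layer sizes and a plain Sequential stack, not the example's actual architecture; assumes channels_last): batch normalization after each block of the generator, and none in the discriminator.

    from keras.models import Sequential
    from keras.layers import Dense, Reshape, Conv2DTranspose, BatchNormalization, Activation

    latent_size = 100  # assumed latent dimension

    g = Sequential()
    g.add(Dense(7 * 7 * 128, input_dim=latent_size))
    g.add(BatchNormalization())
    g.add(Activation('relu'))
    g.add(Reshape((7, 7, 128)))
    g.add(Conv2DTranspose(64, 5, strides=2, padding='same'))
    g.add(BatchNormalization())
    g.add(Activation('relu'))
    # 28x28x1 output in [-1, 1], matching MNIST-sized images
    g.add(Conv2DTranspose(1, 5, strides=2, padding='same', activation='tanh'))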

@lukedeo, I see a lot of examples online which use embedding with a Hadamard product, but do you know of any paper(s) we can reference? I haven't seen it in any of the GAN papers. I really like embedding with Hadamard, since replacing it would require multiple (3-5?) additional layers, but to be thorough I made a half-hearted attempt to remove it (to make the example closer to the ACGAN paper), just to see if I could, and failed (the generated images are much worse).

@ozabluda
Contributor Author

Generated images are better than the previous best (#8482 (comment))

The biggest difference is much larger diversity in line thickness.

Epoch 57
plot_epoch_057_generated

Epoch 99
plot_epoch_099_generated

Epoch 100
plot_epoch_100_generated

@fchollet
Member

fchollet commented Dec 3, 2017

Note that if you want to completely freeze a model that has BN layers (like here), you need to do more than just set it to non-trainable: you need to also disable its updates. Otherwise the batch statistics will still get updated during training. This is not a bug; it is simply a manifestation of the fact that non-backprop updates and layer trainability are independent (for instance, if you set a stateful RNN to non-trainable, that will not freeze its state).

It is possible that we should have a frozen attribute that would both disable trainability and other updates. Or some other API to freeze BN layers.
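
A minimal sketch of the behavior described above, as Keras behaved at the time of this comment (before the trainable change discussed further down in this thread); the layer sizes are arbitrary:

    from keras.layers import Input, Dense, BatchNormalization
    from keras.models import Model

    x = Input(shape=(10,))
    h = BatchNormalization()(Dense(8)(x))
    frozen = Model(x, h)
    frozen.trainable = False

    # The moving-mean/moving-variance update ops are still collected,
    # even though the model is non-trainable:
    print(len(frozen.updates))  # 2, not 0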

@alpapado

alpapado commented Dec 4, 2017

you need to also disable its updates.

Could you elaborate on that please? Do you mean setting every layer's trainable attribute to False?

@fchollet
Member

fchollet commented Dec 4, 2017

No. You'll need to clear the attribute _per_input_updates on every sublayer. That, or find another way to get model.updates to return [].

@ozabluda
Contributor Author

ozabluda commented Dec 4, 2017

I don't think it ever makes sense to freeze those 2 trainable parameters while still updating the 2 non-trainable parameters. That approach should properly be called "Breaking Deep Network Training by Introducing Uncorrectable Internal Covariate Shift". In terms of calling things trainable or non-trainable, I don't see what difference it makes conceptually whether the parameter updates happen during forward prop or backprop. I think those non-trainable weights should always be frozen as well, and renamed _trainable_ when they aren't frozen.

This issue does not affect this PR (which has BN only in the Generator, which is never frozen), but it may be the reason why my attempts to put BN into the Discriminator broke things so badly. The workaround in #4762 (comment) seems to work for the GAN use case, and I'll try it, but not for a while.

@fchollet
Member

fchollet commented Dec 4, 2017

@ozabluda Jeremy Howard did a study on this recently ("should we update the batch statistics during fine-tuning?") and the answer was "it depends". It's not clear-cut.

Beyond batch norm specifically, you seem confused about the difference between trainability and stateful behavior. Setting a layer to non-trainable means its trainable weights will not be taken into account during training. It does not affect the parts of the layer's state that are independent of training. For instance, a layer that maintains a counter incremented by one with every batch will not stop doing that if you set trainable = False, because that has nothing to do with training. Same with stateful RNNs, or with BN updates.

If you want to run your BN layers in inference mode, the way to do it is to pass a static boolean as the training argument:

y = BatchNormalization()(x, training=False)
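
A minimal sketch of that call in context (assumed shapes and layer sizes, not code from the example): with a static training=False, the BN layer always normalizes with its moving statistics and registers no update ops.

    from keras.layers import Input, Dense, BatchNormalization
    from keras.models import Model

    inp = Input(shape=(64,))
    h = Dense(32, activation='relu')(inp)
    h = BatchNormalization()(h, training=False)  # inference-mode BN
    out = Dense(1, activation='sigmoid')(h)
    model = Model(inp, out)
    model.compile(optimizer='adam', loss='binary_crossentropy')

    print(model.updates)  # [] - no BN update ops were created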

@lukedeo
Contributor

lukedeo commented Dec 4, 2017

Unrelated: @ozabluda, I don't know of any references. Tbh, it always just seemed more natural on an intuitive basis.

@ozabluda
Contributor Author

ozabluda commented Dec 4, 2017

@fchollet, I can't find that study by Jeremy Howard.

Is there a recommended way to freeze those two non_trainable weights that would work for the acgan example, if I put BatchNormalization into the Discriminator? y = BatchNormalization()(x, training=False) will work for the combined model, but not for the Discriminator. For the same reason, the idea from my previous comment will not work either.

Right now, the best advice seems to be to go through the discriminator model, immediately after discriminator.trainable = False is set, looking for BatchNormalization layers and setting their _per_input_updates = {}. Is that supposed to work without screwing up the already-compiled discriminator?

Also see tensorflow/tensorflow#10857

@ahundt
Contributor

ahundt commented Dec 5, 2017

It is possible that we should have a frozen attribute that would both disable trainability and other updates. Or some other API to freeze BN layers.

Yes, please!

@fchollet
Member

fchollet commented Dec 5, 2017

Is there a recommended way to freeze those two non_trainable weights that would work for the acgan example

There is no "clean" way at the moment, but we need one. The simplest would be a layer/model attribute that makes the layer/model always return [] when asked for .updates. It would be very straightforward to implement. Could be named frozen, freeze, freeze_updates...

To make sure we have separation of concerns between trainability and updates, it's probably best for this attribute to only act on updates (no effect on trainability) and to explicitly mention updates in the name. Like freeze_updates.

@ozabluda
Contributor Author

ozabluda commented Dec 5, 2017

FWIW, none of the following works to freeze those 2 weights in the BatchNormalization layers I put into the Discriminator in the acgan example:

Immediately after discriminator.trainable = False:

    for layer in discriminator.layers[1].layers:
        if layer.name.startswith('batch_normalization_'):
            layer._per_input_updates={}
            layer._updates=[]

Before and after combined.compile():

    for layer in combined.layers[-1].layers[1].layers:
        if layer.name.startswith('batch_normalization_'):
            layer._per_input_updates={}
            layer._updates=[]

@ahundt
Contributor

ahundt commented Dec 6, 2017

@fchollet I agree, freeze_updates sounds like the best of your suggestions

@ozabluda
Contributor Author

ozabluda commented Dec 6, 2017

@ahundt, I have no immediate plans to work on it, since I don't currently understand that part of the Keras code, and therefore the suggestions. I also don't understand the performance implications, for example compared to native tensorflow/tensorflow#12580.

So if you feel like doing it, that would be great. For reference, this is how I check whether freezing actually works in the acgan example after putting BatchNormalization into the Discriminator (not in this PR), immediately after combined.train_on_batch():

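# Assumption: `weights` holds a snapshot of
# combined.layers[-1].layers[1].layers[5].get_weights() taken just before
# the call to combined.train_on_batch().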
print(combined.layers[-1].layers[1].layers[5].name)
print(np.all([
    np.array_equal(w1, w2) for w1,w2 in
    zip(weights,
        combined.layers[-1].layers[1].layers[5].get_weights())]))

print(np.array_equal(
    weights[2],
    combined.layers[-1].layers[1].layers[5].get_weights()[2]))#0-1 are equal, 2-3 are not

A much simpler example can't be made, because this is the use case for which none of the hacks work, but see the example in #8676 (put BatchNormalization into m2).

@ahundt
Contributor

ahundt commented Dec 7, 2017

@ozabluda sorry, I meant to @-mention fchollet in that previous post; I was addressing the feature he was suggesting. I've edited it in now.

@ahundt
Contributor

ahundt commented Dec 18, 2017

@fchollet Also, I'm not sure whether I mentioned this, but freezing batch normalization updates is quite helpful when fine-tuning from pre-trained weights on segmentation problems. I think something like this modified BatchNormalization class with freeze would do the trick. Would it be acceptable to place the freeze parameter in allowed_kwargs under Layer?

@fchollet
Member

fchollet commented Jan 2, 2018

For the record, on master you can now set layer.updatable = False or model.updatable = False to freeze updates (this mirrors trainable = False).

@lukedeo would you like to review this PR?

@lukedeo
Contributor

lukedeo commented Jan 3, 2018

yep @fchollet, can check by EOW

@fchollet
Member

@ahundt @ozabluda FYI, I think the updatable system is overcomplicated. Pragmatically, what most people want/expect is that setting bn.trainable = False will run BN in inference mode, i.e. without updates. Likewise, updates when trainable == False only make sense for stateful layers (stateful RNNs being the only example of that).

Thus I am reverting updatable (it was never part of a release) and simply disabling updates for non-stateful layers when trainable == False.
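
A minimal sketch of what that change means in practice (arbitrary layer sizes; a reading of the described behavior, not code from this PR):

    from keras.layers import Input, Dense, BatchNormalization
    from keras.models import Model

    x = Input(shape=(16,))
    h = BatchNormalization()(Dense(8)(x))
    d = Model(x, Dense(1, activation='sigmoid')(h))

    d.trainable = False
    # With the change described above, a non-stateful, non-trainable model
    # no longer reports its BN update ops:
    print(d.updates)  # []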

@fchollet
Member

@lukedeo if you're still available to review this PR, we're waiting for your input. Otherwise, please tell and we will find another reviewer. Thanks!

@ahundt
Contributor

ahundt commented Jan 11, 2018

@fchollet thanks for the update; the new functionality you describe is the behavior I imagined when I first saw the name of the flag, back when I was learning the code base a year ago. I look forward to using it!

@ozabluda
Contributor Author

For reference, the changes described in #8616 (comment) were made in 24246ea

@ahundt
Contributor

ahundt commented Jan 12, 2018

Specifically, the revert commit that makes trainable=False prevent updates is 24246ea

@lukedeo
Contributor

lukedeo commented Jan 12, 2018

Hey @fchollet back from vacation. Will look today - sorry.

@lukedeo
Contributor

lukedeo commented Jan 13, 2018

@fchollet this seems fine to me overall.

@ozabluda one question, though not necessarily needed for this PR: should we also add BN to the discriminator, since that is technically closer to the paper? Since we already see good performance we don't strictly need it, but I'm raising it because people might ask, now that I think about it. The one place I could see this maybe helping is if people want to adapt this script to CIFAR10 or that ilk.

@ozabluda
Contributor Author

I tried adding BN to the discriminator (see item 1 in the very first message of this PR). With the recent fixes to BN, maybe I'll try it again later. An attempt to adapt the example to CIFAR-10 is going on here: #8937

@ozabluda
Contributor Author

ozabluda commented Jan 16, 2018

I've created animations of acgan training:
https://www.youtube.com/playlist?list=PL7zaNUNu3zI8O6cBK6rgpLCEXtZerDzU_

Resolutions are 1080p (2K), 1440p (3K), 2160p (4K), and 7680p (8K). There are two types of video:
acgan1: one frame per epoch.
acgan0: one frame per iteration.

Note that on YouTube you can change the playback speed and, when paused, go frame-by-frame with ',' and '.'

Real images are on the bottom; fake (generated) ones are on top. The epoch/iteration is shown in the gray bar in between. Within every column of 10 digits (0-9), the latent noise vector is the same (the class is different). Within every row of the same digit, the latent noise vector is different (the class is the same).

The discriminator's real/fake probability is shown as a grayscale square around each digit. It's scaled so that p=0.5 is pure black and p=0.75 is pure white, with p<0.5 and p>0.75 clipped to the min and max values; otherwise one wouldn't be able to see what is going on.
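
A small sketch of that scaling (a reconstruction from the description above, not necessarily the code used to render the videos):

    import numpy as np

    def prob_to_gray(p):
        # p = 0.5 maps to black (0), p = 0.75 to white (255);
        # values outside [0.5, 0.75] are clipped.
        return np.uint8(255 * np.clip((p - 0.5) / 0.25, 0.0, 1.0))

    print(prob_to_gray(0.4), prob_to_gray(0.5), prob_to_gray(0.75), prob_to_gray(0.9))  # 0 0 255 255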

Staring at these videos for a while falsified a lot of my preconceived notions of what the GAN is actually doing. Too many to describe here. I highly recommend that anyone who is interested stare for a while as well.

@ahundt
Contributor

ahundt commented Jan 16, 2018

@ozabluda can you fix the YouTube link? It doesn't seem to work for me.

@ozabluda
Contributor Author

@ahundt, fixed.

@ahundt
Contributor

ahundt commented Jan 18, 2018

thanks! looks neat

@ozabluda
Contributor Author

Staring at these videos for a while falsified a lot of my preconceived notions of what the GAN is actually doing. Too many to describe here. I highly recommend that anyone who is interested stare for a while as well.

Two quick things:

  1. In this GAN, diversity doesn't come from training (after the first epochs). Mostly it comes from sampling the latent noise. That is not what we want, but it may not matter much in practice, depending.
  2. I blame everything on the discriminator now, unlike the consensus, and unlike what I thought before I stared at these videos for a while.

P.S. I've added acgan0 8K video - 30 min, 100 GB, baby!

@ahundt
Contributor

ahundt commented Jan 19, 2018

P.S. I've added acgan0 8K video - 30 min, 100 GB, baby!

I present you with:

🥇

Awarded for the highest resolution 28x28 image of all time. 👍

P.S. Any hints on what rendering tools you used to generate these? Sounds like something useful

@ozabluda
Contributor Author

Just ffmpeg. The code to generate the PNGs is an embarrassing mess; I'm not sure it's worthwhile to clean it up and add it to the example.
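
For the curious, a minimal sketch of the ffmpeg step (assumed frame rate and filename pattern, following the plot_epoch_NNN_generated names above; not the actual commands used):

    import subprocess

    subprocess.check_call([
        'ffmpeg', '-y',
        '-framerate', '10',                     # assumed frames per second
        '-i', 'plot_epoch_%03d_generated.png',  # assumed frame filename pattern
        '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
        'acgan_training.mp4',
    ])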

@fchollet fchollet merged commit 97acd91 into keras-team:master Jan 19, 2018
@ozabluda ozabluda deleted the patch-8 branch January 19, 2018 20:33
@ahundt
Contributor

ahundt commented Jan 20, 2018

A PNG utility could go in some non-Keras repo, or in keras-contrib.

@Evan05543071

Evan05543071 commented Mar 24, 2019

Specifically, the revert commit that makes trainable=False prevent updates is 24246ea

@ahundt
Does this mean the following? I found that if I set trainable = False on a batch normalization layer, all of the layer's parameters stop updating. However, if I instead set training = False as you described above, the moving mean and moving variance are fixed while the other parameters still change during training.
