Skip to content

Commit

Permalink
acgan: Add batch normalization to the Generator, etc (#8616)
Browse files Browse the repository at this point in the history
* Add batch normalization to the Generator

* Add name of a paper
  • Loading branch information
ozabluda authored and fchollet committed Jan 19, 2018
1 parent 58cf550 commit 97acd91
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions examples/mnist_acgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from keras.datasets import mnist
from keras import layers
from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2DTranspose, Conv2D
from keras.models import Sequential, Model
Expand All @@ -56,11 +57,13 @@ def build_generator(latent_size):
cnn.add(Conv2DTranspose(192, 5, strides=1, padding='valid',
activation='relu',
kernel_initializer='glorot_normal'))
cnn.add(BatchNormalization())

# upsample to (14, 14, ...)
cnn.add(Conv2DTranspose(96, 5, strides=2, padding='same',
activation='relu',
kernel_initializer='glorot_normal'))
cnn.add(BatchNormalization())

# upsample to (28, 28, ...)
cnn.add(Conv2DTranspose(1, 5, strides=2, padding='same',
Expand Down Expand Up @@ -124,7 +127,7 @@ def build_discriminator():
if __name__ == '__main__':

# batch and latent size taken from the paper
epochs = 50
epochs = 100
batch_size = 100
latent_size = 100

Expand Down Expand Up @@ -215,8 +218,10 @@ def build_discriminator():

x = np.concatenate((image_batch, generated_images))

# use soft real/fake labels
soft_zero, soft_one = 0.1, 0.9
# use one-sided soft real/fake labels
# Salimans et al., 2016
# https://arxiv.org/pdf/1606.03498.pdf (Section 3.4)
soft_zero, soft_one = 0, 0.95
y = np.array([soft_one] * batch_size + [soft_zero] * batch_size)
aux_y = np.concatenate((label_batch, sampled_labels), axis=0)

Expand Down Expand Up @@ -286,7 +291,7 @@ def build_discriminator():
'component', *discriminator.metrics_names))
print('-' * 65)

ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.2f} | {3:<5.2f}'
ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.4f} | {3:<5.4f}'
print(ROW_FMT.format('generator (train)',
*train_history['generator'][-1]))
print(ROW_FMT.format('generator (test)',
Expand Down

0 comments on commit 97acd91

Please sign in to comment.