
Change BN layer to use moving mean/var if frozen #9965

Closed
wants to merge 10 commits
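
The title describes the intended behavior: a BatchNormalization layer whose trainable flag is off should normalize with its moving mean/variance rather than the current mini-batch statistics, even when the learning phase is set to training. A minimal usage sketch of that behavior (model setup, shapes, and optimizer are illustrative assumptions, not taken from the PR):

import numpy as np
import keras.backend as K
from keras.models import Sequential
from keras.layers import BatchNormalization

K.set_learning_phase(1)  # force the "training" learning phase, as in the test below
model = Sequential([BatchNormalization(input_shape=(4,))])
model.layers[0].trainable = False  # freeze the BN layer
model.compile(loss='mse', optimizer='rmsprop')

x = np.random.normal(loc=5.0, scale=2.0, size=(32, 4)).astype('float32')
# With this change, the frozen layer normalizes with its moving mean/variance
# (initially 0 and 1) rather than the statistics of this batch, even though
# the learning phase is set to training.
out = model.predict(x)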
5 changes: 4 additions & 1 deletion keras/backend/cntk_backend.py
@@ -80,7 +80,7 @@ def in_train_phase(x, alt, training=None):
training = learning_phase()
uses_learning_phase = True
else:
uses_learning_phase = False
uses_learning_phase = getattr(training, '_uses_learning_phase', False)

# CNTK currently don't support cond op, so here we use
# element_select approach as workaround. It may have
@@ -323,6 +323,9 @@ def is_sparse(tensor):


def int_shape(x):
if type(x) in {int, float}:
return ()

if hasattr(x, '_keras_shape'):
return x._keras_shape

2 changes: 1 addition & 1 deletion keras/backend/tensorflow_backend.py
@@ -2865,7 +2865,7 @@ def in_train_phase(x, alt, training=None):
training = learning_phase()
uses_learning_phase = True
else:
uses_learning_phase = False
uses_learning_phase = getattr(training, '_uses_learning_phase', False)

if training is 1 or training is True:
if callable(x):
2 changes: 1 addition & 1 deletion keras/backend/theano_backend.py
@@ -1496,7 +1496,7 @@ def in_train_phase(x, alt, training=None):
training = learning_phase()
uses_learning_phase = True
else:
uses_learning_phase = False
uses_learning_phase = getattr(training, '_uses_learning_phase', False)

if training is 1 or training is True:
if callable(x):
22 changes: 21 additions & 1 deletion keras/layers/normalization.py
@@ -72,6 +72,8 @@ def __init__(self,
beta_constraint=None,
gamma_constraint=None,
**kwargs):
self._trainable = True
self._trainable_tensor = K.variable(1, dtype='float32', name='trainable')
Contributor:
should this have a name scope?

Contributor Author:
@ahundt Good thought. The challenge is that self.name is not yet defined at this point, so we cannot do:

with K.name_scope(self.name):
    self._trainable_tensor = K.variable(1, dtype='float32', name='trainable')

The problem is that the parent constructor, which initializes the name, also sets the trainable property, so _trainable_tensor must be defined before the parent constructor runs. I blame the "Pythonic OO" model for even allowing this...

My original goal was to minimize duplicated code and avoid overly drastic changes in this PR. If you have any ideas for adding name scopes given these constraints, please let me know! :)
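
For context, a minimal sketch of the ordering constraint described above (the class name is hypothetical; the Layer internals and import path are assumed from Keras 2.1.x, where Layer.__init__ assigns self.trainable and thereby triggers any overriding setter):

from keras import backend as K
from keras.engine.topology import Layer  # Keras 2.1.x import path (assumption)

class FrozenAwareLayer(Layer):  # hypothetical name, for illustration only
    def __init__(self, **kwargs):
        # Must exist before super().__init__(): Layer.__init__ assigns
        # self.trainable, which dispatches to the setter defined below.
        self._trainable = True
        self._trainable_tensor = K.variable(1, dtype='float32', name='trainable')
        # self.name is only assigned inside the parent constructor, so it is
        # not yet available here to open a name scope around the variable.
        super(FrozenAwareLayer, self).__init__(**kwargs)

    @property
    def trainable(self):
        return self._trainable

    @trainable.setter
    def trainable(self, value):
        value = bool(value)
        if self._trainable != value:
            self._trainable = value
            K.set_value(self._trainable_tensor, 1 if value else 0)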

super(BatchNormalization, self).__init__(**kwargs)
self.supports_masking = True
self.axis = axis
@@ -88,6 +90,19 @@ def __init__(self,
self.beta_constraint = constraints.get(beta_constraint)
self.gamma_constraint = constraints.get(gamma_constraint)

@property
def trainable(self):
# Use cached value to avoid unnecessary get_value() calls
return self._trainable

@trainable.setter
def trainable(self, trainable):
trainable = bool(trainable)
# Change when different to avoid unnecessary set_value() calls
if self._trainable != trainable:
self._trainable = trainable
K.set_value(self._trainable_tensor, 1 if trainable else 0)

def build(self, input_shape):
dim = input_shape[self.axis]
if dim is None:
@@ -171,9 +186,14 @@ def normalize_inference():
self.gamma,
epsilon=self.epsilon)

# If the learning phase is *static* and set to inference:
if training in {0, False}:
# If the learning phase is *static* and set to inference:
return normalize_inference()
elif training is None:
# If it's undefined: respect the learning phase when the trainable tensor is on, otherwise force inference
training = K.switch(self._trainable_tensor, K.cast(K.learning_phase(), 'float32'),
K.constant(0, dtype='float32'))
training._uses_learning_phase = True

# If the learning is either dynamic, or set to training:
normed_training, mean, variance = K.normalize_batch_in_training(
12 changes: 12 additions & 0 deletions tests/keras/layers/normalization_test.py
@@ -236,12 +236,24 @@ def get_model(bn_mean, bn_std):
model.set_weights([np.array([1.]), np.array([0.]),
np.array([bn_mean]), np.array([bn_std ** 2])])
return model

# Simulate training mode with a trainable layer: mini-batch statistics should be used.
K.set_learning_phase(1)
model = get_model(bn_mean, bn_std)
model.layers[1].trainable = True
model.compile(loss='mse', optimizer='rmsprop')
out = model.predict(input_4)
assert_allclose((input_4 - np.mean(input_4)) / np.std(input_4), out, atol=1e-3)

# In all other cases we should use the moving mean and variance from BN.
for lp, trainable in [(1, False), (0, True), (0, False)]:
K.set_learning_phase(lp)
model = get_model(bn_mean, bn_std)
model.layers[1].trainable = trainable
model.compile(loss='mse', optimizer='rmsprop')
out = model.predict(input_4)
assert_allclose((input_4 - bn_mean) / bn_std, out, atol=1e-3)


if __name__ == '__main__':
pytest.main([__file__])