Align BN trainable behaviour with TF 2.0 #13892

Closed
wants to merge 13 commits into from
5 changes: 4 additions & 1 deletion keras/backend/cntk_backend.py
@@ -85,7 +85,7 @@ def in_train_phase(x, alt, training=None):
         training = learning_phase()
         uses_learning_phase = True
     else:
-        uses_learning_phase = False
+        uses_learning_phase = getattr(training, '_uses_learning_phase', False)

     # CNTK currently don't support cond op, so here we use
     # element_select approach as workaround. It may have
@@ -333,6 +333,9 @@ def is_sparse(tensor):


 def int_shape(x):
+    if type(x) in {int, float}:
+        return ()
+
     if hasattr(x, '_keras_shape'):
         return x._keras_shape

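Note: a minimal self-contained sketch of the behaviour the int_shape change above introduces (the helper name int_shape_sketch is illustrative, not part of the PR) — plain Python numbers are now reported as rank-0 values, i.e. an empty shape tuple, instead of falling through to the tensor attribute lookups:

def int_shape_sketch(x):
    # Mirrors the new guard: treat plain Python scalars as rank-0 values.
    if type(x) in {int, float}:
        return ()
    # Fall back to the Keras bookkeeping attribute, roughly as the real
    # backend function does before querying the tensor itself.
    return getattr(x, '_keras_shape', None)

assert int_shape_sketch(1.0) == ()
assert int_shape_sketch(3) == ()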
2 changes: 1 addition & 1 deletion keras/backend/tensorflow_backend.py
@@ -3105,7 +3105,7 @@ def in_train_phase(x, alt, training=None):
         training = learning_phase()
         uses_learning_phase = True
     else:
-        uses_learning_phase = False
+        uses_learning_phase = getattr(training, '_uses_learning_phase', False)

     if training is 1 or training is True:
         if callable(x):
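Note: a rough, self-contained sketch of what the rewritten else branch does — the same one-line change appears in the CNTK, TensorFlow and Theano backends. FakeTensor and infer_uses_learning_phase are illustrative stand-ins, not Keras API: when the caller hands in_train_phase a training tensor that itself depends on the learning phase, the result is now flagged as learning-phase-dependent instead of always being marked static:

class FakeTensor(object):
    # Stand-in for a backend tensor carrying the Keras bookkeeping flag.
    def __init__(self, uses_learning_phase=False):
        self._uses_learning_phase = uses_learning_phase

def infer_uses_learning_phase(training):
    # Mirrors the rewritten else branch of in_train_phase.
    if training is None:
        return True
    return getattr(training, '_uses_learning_phase', False)

assert infer_uses_learning_phase(None) is True
assert infer_uses_learning_phase(FakeTensor(True)) is True
assert infer_uses_learning_phase(FakeTensor(False)) is False
assert infer_uses_learning_phase(1) is False  # a static training flag, as before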
2 changes: 1 addition & 1 deletion keras/backend/theano_backend.py
@@ -1648,7 +1648,7 @@ def in_train_phase(x, alt, training=None):
         training = learning_phase()
         uses_learning_phase = True
     else:
-        uses_learning_phase = False
+        uses_learning_phase = getattr(training, '_uses_learning_phase', False)

     if training is 1 or training is True:
         if callable(x):
24 changes: 23 additions & 1 deletion keras/layers/normalization.py
@@ -73,6 +73,8 @@ def __init__(self,
                  beta_constraint=None,
                  gamma_constraint=None,
                  **kwargs):
+        self._trainable = True
+        self._trainable_tensor = K.variable(1, dtype='float32', name='trainable')
         super(BatchNormalization, self).__init__(**kwargs)
         self.supports_masking = True
         self.axis = axis
@@ -90,6 +92,19 @@ def __init__(self,
         self.beta_constraint = constraints.get(beta_constraint)
         self.gamma_constraint = constraints.get(gamma_constraint)

+    @property
+    def trainable(self):
+        # Use cached value to avoid unnecessary get_value() calls
+        return self._trainable
+
+    @trainable.setter
+    def trainable(self, trainable):
+        trainable = bool(trainable)
+        # Change when different to avoid unnecessary set_value() calls
+        if self._trainable != trainable:
+            self._trainable = trainable
+            K.set_value(self._trainable_tensor, 1 if trainable else 0)
+
     def build(self, input_shape):
         dim = input_shape[self.axis]
         if dim is None:
@@ -175,9 +190,16 @@ def normalize_inference():
                     axis=self.axis,
                     epsilon=self.epsilon)

-        # If the learning phase is *static* and set to inference:
         if training in {0, False}:
+            # If the learning phase is *static* and set to inference:
             return normalize_inference()
+        elif training is None:
+            # If it's undefined then if trainable tensor is on then
+            # respect learning phase else set to false
+            training = K.switch(self._trainable_tensor,
+                                K.cast(K.learning_phase(), 'float32'),
+                                K.constant(0, dtype='float32'))
+            training._uses_learning_phase = True

         # If the learning is either dynamic, or set to training:
         normed_training, mean, variance = K.normalize_batch_in_training(
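Note: a hedged usage sketch of the behaviour the layer changes above aim for, mirroring TF 2.0 semantics (the model shape and data below are illustrative, not taken from the PR): because the inference/training switch is driven by the _trainable_tensor backend variable, flipping trainable on a BatchNormalization layer toggles between mini-batch statistics and the stored moving statistics without rebuilding the graph:

import numpy as np
from keras import backend as K
from keras.models import Sequential
from keras.layers import BatchNormalization

K.set_learning_phase(1)                  # training-mode learning phase
model = Sequential([BatchNormalization(input_shape=(4,))])
model.compile(loss='mse', optimizer='rmsprop')
x = np.random.normal(loc=5.0, scale=2.0, size=(1000, 4)).astype('float32')

model.layers[0].trainable = True         # expected: mini-batch statistics
out_batch_stats = model.predict(x)

model.layers[0].trainable = False        # expected: moving mean/variance
out_moving_stats = model.predict(x)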
10 changes: 10 additions & 0 deletions tests/keras/layers/normalization_test.py
@@ -235,10 +235,20 @@ def get_model(bn_mean, bn_std):
     # Simulates training-mode with trainable layer. Should use mini-batch statistics.
     K.set_learning_phase(1)
     model = get_model(bn_mean, bn_std)
+    model.layers[1].trainable = True
     model.compile(loss='mse', optimizer='rmsprop')
     out = model.predict(input_4)
     assert_allclose((input_4 - np.mean(input_4)) / np.std(input_4), out, atol=1e-3)

+    # In all other cases we should use the moving mean and variance from BN.
+    for lp, trainable in [(1, False), (0, True), (0, False)]:
+        K.set_learning_phase(lp)
+        model = get_model(bn_mean, bn_std)
+        model.layers[1].trainable = trainable
+        model.compile(loss='mse', optimizer='rmsprop')
+        out = model.predict(input_4)
+        assert_allclose((input_4 - bn_mean) / bn_std, out, atol=1e-3)
+

 if __name__ == '__main__':
     pytest.main([__file__])