
Chenta/cntk bn #9952

Merged
4 commits merged on Apr 17, 2018
keras/backend/cntk_backend.py (35 changes: 26 additions & 9 deletions)
@@ -24,7 +24,10 @@

# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_LEARNING_PHASE = C.constant(shape=(), dtype=np.float32, value=1.0, name='_keras_learning_phase')
+# _LEARNING_PHASE_PLACEHOLDER is the placeholder for the dynamic learning phase.
+_LEARNING_PHASE_PLACEHOLDER = C.constant(shape=(), dtype=np.float32, value=1.0, name='_keras_learning_phase')
+# Static learning phase flag; if it is not 0 or 1, we fall back to the dynamic learning phase tensor.
+_LEARNING_PHASE = -1
_UID_PREFIXES = defaultdict(int)

# cntk doesn't support gradient as symbolic op, to hook up with keras model,
@@ -49,8 +52,8 @@ def get_uid(prefix=''):


def learning_phase():
-    # False = test, True = train
-    return _LEARNING_PHASE
+    # If _LEARNING_PHASE is not 0 or 1, return the dynamic learning phase tensor.
+    return _LEARNING_PHASE if _LEARNING_PHASE in {0, 1} else _LEARNING_PHASE_PLACEHOLDER


def set_learning_phase(value):
@@ -59,8 +62,16 @@ def set_learning_phase(value):
        raise ValueError('CNTK Backend: Set learning phase '
                         'with value %s is not supported, '
                         'expected 0 or 1.' % value)
-    v = np.asarray(value)
-    _LEARNING_PHASE.value = v
+    _LEARNING_PHASE = value


[Review comment from a Contributor]: I like this. Every backend should have a way to reset the state of Learning Phase even if it's not possible to clear the memory.

+def clear_session():
+    """Reset learning phase flag for cntk backend.
+    """
+    global _LEARNING_PHASE
+    global _LEARNING_PHASE_PLACEHOLDER
+    _LEARNING_PHASE = -1
+    _LEARNING_PHASE_PLACEHOLDER.value = np.asarray(1.0)
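
A short usage sketch of the reset (hypothetical session, same assumptions as above):

import keras.backend as K

K.set_learning_phase(0)      # pin the phase statically to test mode
assert K.learning_phase() == 0

K.clear_session()            # static flag back to -1, placeholder back to 1.0
assert not isinstance(K.learning_phase(), int)  # dynamic tensor again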


@@ -83,7 +94,11 @@ def in_train_phase(x, alt, training=None):
        x._uses_learning_phase = uses_learning_phase
        return x
    else:
-        result = C.element_select(training, x, alt)
+        # If the learning phase is static, pick the branch at construction time.
+        if isinstance(training, int) or isinstance(training, bool):
+            result = x if training == 1 or training is True else alt
+        else:
+            result = C.element_select(training, x, alt)
        result._uses_learning_phase = uses_learning_phase
        return result
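
To make the two paths concrete, a sketch (not part of the diff; K is the CNTK backend, and the tensors are stand-ins):

import numpy as np
import keras.backend as K

x = K.constant(np.array([1.0]))
alt = K.constant(np.array([0.0]))

# Static phase: a Python int/bool is resolved at graph-construction time,
# so no C.element_select node is ever built.
train_out = K.in_train_phase(x, alt, training=1)  # simply x
test_out = K.in_train_phase(x, alt, training=0)   # simply alt

# Dynamic phase: a tensor defers the branch choice to evaluation time
# via C.element_select.
dyn_out = K.in_train_phase(x, alt, training=K.learning_phase())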

@@ -1854,6 +1869,7 @@ def _is_input_shape_compatible(input, placeholder):
        return True

    def __call__(self, inputs):
+        global _LEARNING_PHASE_PLACEHOLDER
        global _LEARNING_PHASE
        assert isinstance(inputs, (list, tuple))
        feed_dict = {}
@@ -1864,8 +1880,8 @@ def __call__(self, inputs):
                    value.dtype != np.float64):
                value = value.astype(np.float32)

-            if tensor == _LEARNING_PHASE:
-                _LEARNING_PHASE.value = np.asarray(value)
+            if tensor == _LEARNING_PHASE_PLACEHOLDER:
+                _LEARNING_PHASE_PLACEHOLDER.value = np.asarray(value)
            else:
                # In the current version cntk can't support input with
                # variable length. Will support it in the next release.
@@ -1911,7 +1927,8 @@ def __call__(self, inputs):
# "forward" method to let cntk know we want to evaluate them.from
# But the assign ops won't be executed under this mode, that's why
# we need this check.
if self.unrelated_updates is None and _LEARNING_PHASE.value == 1.0:
if (self.unrelated_updates is None and
(_LEARNING_PHASE_PLACEHOLDER.value == 1.0 or _LEARNING_PHASE == 1)):
_, output_values = self.metrics_func.forward(
input_dict,
self.metrics_func.outputs,
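
This feed path is what lets the usual Keras pattern of passing the phase into a backend function work on cntk; a sketch, under the assumption that the placeholder handling in this PR behaves like the TensorFlow backend's:

import numpy as np
import keras.backend as K

x = K.placeholder(shape=(1, 1))
# An output whose value depends on the dynamic learning phase.
y = K.in_train_phase(x * 2, x, training=K.learning_phase())

f = K.function([x, K.learning_phase()], [y])
data = np.array([[1.0]])
print(f([data, 1]))  # feeds 1.0 into the placeholder constant -> [[2.]]
print(f([data, 0]))  # feeds 0.0 -> [[1.]]
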
keras/utils/test_utils.py (2 changes: 1 addition & 1 deletion)
@@ -159,7 +159,7 @@ def keras_test(func):
    @six.wraps(func)
    def wrapper(*args, **kwargs):
        output = func(*args, **kwargs)
-        if K.backend() == 'tensorflow':
+        if K.backend() == 'tensorflow' or K.backend() == 'cntk':
            K.clear_session()
        return output
    return wrapper
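
A sketch of why the cntk backend now needs this reset (hypothetical test, not part of the diff): the static phase is module-global state, so without the clear it would leak across test functions.

@keras_test
def test_pins_phase():
    K.set_learning_phase(1)  # global static flag on the cntk backend
    # ... assertions ...
    # keras_test invokes K.clear_session() on exit, so the next test
    # starts from the dynamic default (-1) instead of inheriting 1.
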
tests/keras/layers/normalization_test.py (20 changes: 20 additions & 0 deletions)
@@ -12,6 +12,7 @@
input_1 = np.arange(10)
input_2 = np.zeros(10)
input_3 = np.ones((10))
+input_4 = np.expand_dims(np.arange(10.), axis=1)
input_shapes = [np.ones((10, 10)), np.ones((10, 10, 10))]


@@ -223,5 +224,24 @@ def test_that_trainable_disables_updates():
    assert_allclose(x1, x2, atol=1e-7)


+@keras_test
+def test_batchnorm_trainable():
+    bn_mean = 0.5
+    bn_std = 10.
+
+    def get_model(bn_mean, bn_std):
+        input = Input(shape=(1,))
+        x = normalization.BatchNormalization()(input)
+        model = Model(input, x)
+        model.set_weights([np.array([1.]), np.array([0.]),
+                           np.array([bn_mean]), np.array([bn_std ** 2])])
+        return model
+
+    # Simulates training mode with a trainable layer; should use mini-batch statistics.
+    K.set_learning_phase(1)
+    model = get_model(bn_mean, bn_std)
+    model.compile(loss='mse', optimizer='rmsprop')
+    out = model.predict(input_4)
+    assert_allclose((input_4 - np.mean(input_4)) / np.std(input_4), out, atol=1e-3)
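
For contrast, a hypothetical inference-mode counterpart (not in this PR; assumes get_model is hoisted to module scope): with the phase pinned to 0, the layer should normalize with the injected moving statistics, i.e. (input_4 - bn_mean) / bn_std.

@keras_test
def test_batchnorm_inference_sketch():
    bn_mean = 0.5
    bn_std = 10.
    K.set_learning_phase(0)             # static test mode
    model = get_model(bn_mean, bn_std)  # reuses the helper sketched above
    out = model.predict(input_4)
    # Moving statistics, not mini-batch statistics, drive this result.
    assert_allclose((input_4 - bn_mean) / bn_std, out, atol=1e-3)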

if __name__ == '__main__':
    pytest.main([__file__])