Skip to content

Commit

Permalink
force to use test mode in bn v7 (#216)
Browse files Browse the repository at this point in the history
* force to use test mode in bn v7

* fix backend test case

* add is_test to common

* try getenv

* Revert "try getenv"

This reverts commit 58d46f3.

* Revert "add is_test to common"

This reverts commit ae11c0c.
  • Loading branch information
fumihwh authored Jun 19, 2018
1 parent ce795a6 commit 8d19f27
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
3 changes: 2 additions & 1 deletion onnx_tf/handlers/backend/batch_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def _common(cls, node, **kwargs):
running_variance = tf.reshape(tensor_dict[node.inputs[4]],
params_shape_broadcast)

if node.attrs.get("is_test", 0):
# from version 7, force to use test mode
if cls.SINCE_VERSION >= 7 or node.attrs.get("is_test", 0):
inputs = [x, running_mean, running_variance, bias, scale]
return [cls.make_tensor_from_onnx_node(node, inputs=inputs)]
spatial = node.attrs.get("spatial", 1) == 1
Expand Down
3 changes: 0 additions & 3 deletions test/backend/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,13 @@ def test_batch_normalization(self):
"BatchNormalization", ["X", "scale", "bias", "mean", "var"], ["Y"],
epsilon=0.001)
x_shape = [3, 5, 4, 2]
momentum = 0.9
param_shape = [5]
_param_shape = [1, 5, 1, 1]
x = self._get_rnd(x_shape, 0, 1)
m = self._get_rnd(param_shape, 0, 1)
_m = m.reshape(_param_shape)
_m = _m * momentum + np.mean(x, axis=0) * (1 - momentum)
v = self._get_rnd(param_shape, 0, 1)
_v = v.reshape(_param_shape)
_v = _v * momentum + np.var(x, axis=0) * (1 - momentum)
scale = self._get_rnd(param_shape, 0, 1)
_scale = scale.reshape(_param_shape)
bias = self._get_rnd(param_shape, 0, 1)
Expand Down

0 comments on commit 8d19f27

Please sign in to comment.