Skip to content

Commit

Permalink
fix compute_output_shape behavior (#2678)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhaopudark committed Feb 24, 2022
1 parent da14c3b commit 18c8367
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
6 changes: 0 additions & 6 deletions tensorflow_addons/layers/normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ def get_config(self):
base_config = super().get_config()
return {**base_config, **config}

def compute_output_shape(self, input_shape):
return input_shape

def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):

group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
Expand Down Expand Up @@ -447,9 +444,6 @@ def call(self, inputs):
normalized_inputs = inputs * tf.math.rsqrt(nu2 + epsilon)
return self.gamma * normalized_inputs + self.beta

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
config = {
"axis": self.axis,
Expand Down
66 changes: 66 additions & 0 deletions tensorflow_addons/layers/tests/normalizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,52 @@ def test_groupnorm_convnet_no_center_no_scale():
)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("scale", [True, False])
def test_group_norm_compute_output_shape(center, scale):
    """Check variable counts after build(), compute_output_shape() and call().

    Each entry point is exercised on its own fresh layer. Afterwards the
    layer must report one variable per enabled flag (`center`, `scale`) and
    the same number of trainable variables.
    """
    shape = (8, 28, 28, 16)
    # Expected count is simply how many of the two flags are switched on.
    expected = int(center) + int(scale)

    def fresh_layer():
        return GroupNormalization(groups=2, center=center, scale=scale)

    def check(layer):
        assert len(layer.variables) == expected
        assert len(layer.trainable_variables) == expected

    # Path 1: explicit build().
    via_build = fresh_layer()
    via_build.build(input_shape=list(shape))
    check(via_build)

    # Path 2: compute_output_shape() on a layer that was never built.
    via_shape = fresh_layer()
    via_shape.compute_output_shape(input_shape=list(shape))
    check(via_shape)

    # Path 3: calling the layer on a concrete tensor.
    via_call = fresh_layer()
    via_call(tf.random.normal(shape=list(shape)))
    check(via_call)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("scale", [True, False])
def test_instance_norm_compute_output_shape(center, scale):
    """Check variable counts after build(), compute_output_shape() and call().

    Each entry point is exercised on its own fresh `InstanceNormalization`
    layer. Afterwards the layer must report one variable per enabled flag
    (`center`, `scale`) and the same number of trainable variables.
    """
    input_shape = [8, 28, 28, 16]
    # Expected count is how many of the two flags are switched on.
    expected_len = [center, scale].count(True)

    def make_layer():
        # NOTE(review): no `groups` argument is passed. InstanceNormalization
        # is understood to fix the group count itself and to warn when a
        # caller supplies one, so the previous `groups=2` had no effect on
        # the layer under test -- confirm against the layer's constructor.
        return InstanceNormalization(center=center, scale=scale)

    layer1 = make_layer()
    layer1.build(input_shape=list(input_shape))  # build()
    assert len(layer1.variables) == expected_len
    assert len(layer1.trainable_variables) == expected_len

    layer2 = make_layer()
    layer2.compute_output_shape(input_shape=list(input_shape))  # compute_output_shape()
    assert len(layer2.variables) == expected_len
    assert len(layer2.trainable_variables) == expected_len

    layer3 = make_layer()
    layer3(tf.random.normal(shape=list(input_shape)))  # call()
    assert len(layer3.variables) == expected_len
    assert len(layer3.trainable_variables) == expected_len


def calculate_frn(
x, beta=0.2, gamma=1, eps=1e-6, learned_epsilon=False, dtype=np.float32
):
Expand Down Expand Up @@ -471,3 +517,23 @@ def test_filter_response_normalization_save(tmpdir):
model.save(filepath, save_format="h5")
filepath = str(tmpdir / "test")
model.save(filepath, save_format="tf")


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_filter_response_norm_compute_output_shape():
    """Check variable counts after build(), compute_output_shape() and call().

    A default-constructed `FilterResponseNormalization` layer is expected to
    report two variables, and two trainable variables, whichever of the
    three entry points is used first.
    """
    shape = (8, 28, 28, 16)
    expected = 2

    # The three entry points, applied in order to a fresh layer each time.
    triggers = (
        lambda layer: layer.build(input_shape=list(shape)),  # build()
        lambda layer: layer.compute_output_shape(
            input_shape=list(shape)
        ),  # compute_output_shape()
        lambda layer: layer(tf.random.normal(shape=list(shape))),  # call()
    )

    for trigger in triggers:
        layer = FilterResponseNormalization()
        trigger(layer)
        assert len(layer.variables) == expected
        assert len(layer.trainable_variables) == expected

0 comments on commit 18c8367

Please sign in to comment.