Skip to content

Commit

Permalink
Fix convlstm correctness.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 19, 2023
1 parent 4dd75e0 commit 2cae421
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 69 deletions.
32 changes: 16 additions & 16 deletions keras_core/layers/rnn/conv_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
Expand Down Expand Up @@ -153,30 +153,30 @@ def __init__(
self.input_spec = InputSpec(ndim=rank + 2)
self.state_size = -1 # Custom, defined in methods

def build(self, input_shape):
def build(self, inputs_shape, states_shape=None):
if self.data_format == "channels_first":
channel_axis = 1
self.spatial_dims = input_shape[2:]
self.spatial_dims = inputs_shape[2:]
else:
channel_axis = -1
self.spatial_dims = input_shape[1:-1]
self.spatial_dims = inputs_shape[1:-1]
if None in self.spatial_dims:
raise ValueError(
"ConvLSTM layers only support static "
"input shapes for the spatial dimension. "
f"Received invalid input shape: input_shape={input_shape}"
f"Received invalid input shape: input_shape={inputs_shape}"
)
if input_shape[channel_axis] is None:
if inputs_shape[channel_axis] is None:
raise ValueError(
"The channel dimension of the inputs (last axis) should be "
"defined. Found None. Full input shape received: "
f"input_shape={input_shape}"
f"input_shape={inputs_shape}"
)
self.input_spec = InputSpec(
ndim=self.rank + 3, shape=(None,) + input_shape[1:]
ndim=self.rank + 3, shape=(None,) + inputs_shape[1:]
)

input_dim = input_shape[channel_axis]
input_dim = inputs_shape[channel_axis]
self.input_dim = input_dim
self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
recurrent_kernel_shape = self.kernel_size + (
Expand Down Expand Up @@ -234,13 +234,13 @@ def call(self, inputs, states, training=False):
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state

dp_mask = self.get_dropout_mask(inputs)
rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
# dp_mask = self.get_dropout_mask(inputs)
# rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)

if training and 0.0 < self.dropout < 1.0:
inputs *= dp_mask
if training and 0.0 < self.recurrent_dropout < 1.0:
h_tm1 *= rec_dp_mask
# if training and 0.0 < self.dropout < 1.0:
# inputs *= dp_mask
# if training and 0.0 < self.recurrent_dropout < 1.0:
# h_tm1 *= rec_dp_mask

inputs_i = inputs
inputs_f = inputs
Expand Down Expand Up @@ -466,7 +466,7 @@ def __init__(
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
Expand Down
2 changes: 1 addition & 1 deletion keras_core/layers/rnn/conv_lstm1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
Expand Down
44 changes: 22 additions & 22 deletions keras_core/layers/rnn/conv_lstm1d_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np

from keras_core import initializers
from keras_core import layers
from keras_core import testing

Expand Down Expand Up @@ -43,25 +46,22 @@ def test_basics(self):
supports_masking=True,
)

# TODO: correctness testing

# def test_correctness(self):
# sequence = np.arange(120).reshape((2, 3, 4, 5)).astype("float32")
# layer = layers.ConvLSTM1D(
# filters=2,
# kernel_size=3,
# kernel_initializer=initializers.Constant(0.001),
# recurrent_initializer=initializers.Constant(0.0),
# bias_initializer=initializers.Constant(0.3),
# use_bias=False,
# )
# output = layer(sequence)
# self.assertAllClose(
# np.array(
# [
# [[0.49877906, 0.49877906], [0.5447451, 0.5447451]],
# [[0.94260275, 0.94260275], [0.95974874, 0.95974874]],
# ]
# ),
# output,
# )
def test_correctness(self):
sequence = np.arange(120).reshape((2, 3, 4, 5)).astype("float32") / 10
layer = layers.ConvLSTM1D(
filters=2,
kernel_size=3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[[0.40807986, 0.40807986], [0.46421072, 0.46421072]],
[[0.80933154, 0.80933154], [0.8233646, 0.8233646]],
]
),
output,
)
2 changes: 1 addition & 1 deletion keras_core/layers/rnn/conv_lstm2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
Expand Down
57 changes: 30 additions & 27 deletions keras_core/layers/rnn/conv_lstm2d_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np

from keras_core import initializers
from keras_core import layers
from keras_core import testing

Expand Down Expand Up @@ -43,30 +46,30 @@ def test_basics(self):
supports_masking=True,
)

# TODO: correctness testing

# def test_correctness(self):
# sequence = np.arange(480).reshape((2, 3, 4, 4, 5)).astype("float32")
# layer = layers.ConvLSTM2D(
# filters=2,
# kernel_size=3,
# kernel_initializer=initializers.Constant(0.0001),
# recurrent_initializer=initializers.Constant(0.01),
# bias_initializer=initializers.Constant(0.01),
# )
# output = layer(sequence)
# self.assertAllClose(
# np.array(
# [
# [
# [[0.4320268, 0.4320268], [0.4475501, 0.4475501]],
# [[0.49229687, 0.49229687], [0.50656533, 0.50656533]],
# ],
# [
# [[0.8781725, 0.8781725], [0.88340145, 0.88340145]],
# [[0.8988858, 0.8988858], [0.9039862, 0.9039862]],
# ],
# ]
# ),
# output,
# )
def test_correctness(self):
sequence = (
np.arange(480).reshape((2, 3, 4, 4, 5)).astype("float32") / 100
)
layer = layers.ConvLSTM2D(
filters=2,
kernel_size=3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[
[[0.48694518, 0.48694518], [0.50237733, 0.50237733]],
[[0.5461202, 0.5461202], [0.5598283, 0.5598283]],
],
[
[[0.8661607, 0.8661607], [0.86909103, 0.86909103]],
[[0.8774414, 0.8774414], [0.8800861, 0.8800861]],
],
]
),
output,
)
2 changes: 1 addition & 1 deletion keras_core/layers/rnn/conv_lstm3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
data_format=None,
dilation_rate=1,
activation="tanh",
recurrent_activation="hard_sigmoid",
recurrent_activation="sigmoid",
use_bias=True,
kernel_initializer="glorot_uniform",
recurrent_initializer="orthogonal",
Expand Down
49 changes: 48 additions & 1 deletion keras_core/layers/rnn/conv_lstm3d_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import numpy as np

from keras_core import initializers
from keras_core import layers
from keras_core import testing

Expand Down Expand Up @@ -43,4 +46,48 @@ def test_basics(self):
supports_masking=True,
)

# TODO: correctness testing
def test_correctness(self):
sequence = (
np.arange(1920).reshape((2, 3, 4, 4, 4, 5)).astype("float32") / 100
)
layer = layers.ConvLSTM3D(
filters=2,
kernel_size=3,
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
bias_initializer=initializers.Constant(0.03),
)
output = layer(sequence)
self.assertAllClose(
np.array(
[
[
[
[
[0.99149036, 0.99149036],
[0.99180907, 0.99180907],
],
[[0.99258363, 0.99258363], [0.9927925, 0.9927925]],
],
[
[
[0.99413764, 0.99413764],
[0.99420583, 0.99420583],
],
[[0.9943788, 0.9943788], [0.9944278, 0.9944278]],
],
],
[
[
[[0.9950547, 0.9950547], [0.9950547, 0.9950547]],
[[0.9950547, 0.9950547], [0.9950547, 0.9950547]],
],
[
[[0.9950547, 0.9950547], [0.9950547, 0.9950547]],
[[0.9950547, 0.9950547], [0.9950547, 0.9950547]],
],
],
]
),
output,
)
47 changes: 47 additions & 0 deletions keras_core/layers/rnn/conv_lstm_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

from keras_core import initializers
from keras_core import testing
from keras_core.layers.rnn.conv_lstm import ConvLSTM
from keras_core.layers.rnn.conv_lstm import ConvLSTMCell


class ConvLSTMCellTest(testing.TestCase):
def test_correctness(self):
x = np.arange(150).reshape((2, 5, 5, 3)).astype("float32") / 10
s1 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 10
s2 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 10

layer = ConvLSTMCell(
rank=2,
filters=4,
kernel_size=3,
padding="same",
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
)
output = layer(x, [s1, s2])
checksum_0 = np.sum(output[0])
self.assertAllClose(checksum_0, 188.89502)
checksum_1 = np.sum(output[1][0])
self.assertAllClose(checksum_1, 188.89502)
checksum_2 = np.sum(output[1][1])
self.assertAllClose(checksum_2, 2170.444)


class ConvLSTMTest(testing.TestCase):
def test_correctness(self):
x = np.arange(450).reshape((2, 3, 5, 5, 3)).astype("float32") / 100
s1 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 100
s2 = np.arange(200).reshape((2, 5, 5, 4)).astype("float32") / 100

layer = ConvLSTM(
rank=2,
filters=4,
kernel_size=3,
padding="same",
kernel_initializer=initializers.Constant(0.01),
recurrent_initializer=initializers.Constant(0.02),
)
output = layer(x, initial_state=[s1, s2])
self.assertAllClose(np.sum(output), 119.812454)

0 comments on commit 2cae421

Please sign in to comment.