Skip to content

Commit

Permalink
Support return_state parameter in ConvRecurrent2D (#7407)
Browse files Browse the repository at this point in the history
* Support return_state parameter in ConvRecurrent2D

Added support for the return_state=True parameter in the ConvRecurrent2D abstract class.

* Fix under-indent error

* Support for return_state parameter for ConvLSTM2D

* return_state unit test

* Fix PEP8 errors

* Fix PEP8 errors

* Update convolutional_recurrent.py

* Fix TF functionality

* Fix Python3.5 error

* Formatting and comment changes

* Adding channels_first

* runtime error

* remove batch size check
  • Loading branch information
ericwu09 authored and fchollet committed Jul 26, 2017
1 parent c5ff4c9 commit cafa286
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
43 changes: 26 additions & 17 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,16 +129,24 @@ def compute_output_shape(self, input_shape):
dilation=self.dilation_rate[1])
if self.return_sequences:
if self.data_format == 'channels_first':
return (input_shape[0], input_shape[1],
self.filters, rows, cols)
output_shape = (input_shape[0], input_shape[1],
self.filters, rows, cols)
elif self.data_format == 'channels_last':
return (input_shape[0], input_shape[1],
rows, cols, self.filters)
output_shape = (input_shape[0], input_shape[1],
rows, cols, self.filters)
else:
if self.data_format == 'channels_first':
return (input_shape[0], self.filters, rows, cols)
output_shape = (input_shape[0], self.filters, rows, cols)
elif self.data_format == 'channels_last':
return (input_shape[0], rows, cols, self.filters)
output_shape = (input_shape[0], rows, cols, self.filters)

if self.return_state:
if self.data_format == 'channels_first':
output_shape = [output_shape] + [(input_shape[0], self.filters, rows, cols) for _ in range(2)]
elif self.data_format == 'channels_last':
output_shape = [output_shape] + [(input_shape[0], rows, cols, self.filters) for _ in range(2)]

return output_shape

def get_config(self):
config = {'filters': self.filters,
Expand Down Expand Up @@ -429,24 +437,25 @@ def reset_states(self):
'input_shape must be provided '
'(including batch size). '
'Got input shape: ' + str(input_shape))

if self.return_sequences:
out_row, out_col, out_filter = output_shape[2:]
if self.return_state:
output_shape = output_shape[1]
else:
output_shape = (input_shape[0],) + output_shape[2:]
else:
out_row, out_col, out_filter = output_shape[1:]
if self.return_state:
output_shape = output_shape[1]
else:
output_shape = (input_shape[0],) + output_shape[1:]

if hasattr(self, 'states'):
K.set_value(self.states[0],
np.zeros((input_shape[0],
out_row, out_col, out_filter)))
np.zeros(output_shape))
K.set_value(self.states[1],
np.zeros((input_shape[0],
out_row, out_col, out_filter)))
np.zeros(output_shape))
else:
self.states = [K.zeros((input_shape[0],
out_row, out_col, out_filter)),
K.zeros((input_shape[0],
out_row, out_col, out_filter))]
self.states = [K.zeros(output_shape),
K.zeros(output_shape)]

def get_constants(self, inputs, training=None):
constants = []
Expand Down
19 changes: 19 additions & 0 deletions tests/keras/layers/convolutional_recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,25 @@ def test_convolutional_recurrent():
input_channel)

for return_sequences in [True, False]:

# test for return state:
input = Input(batch_shape=inputs.shape)
kwargs = {'data_format': data_format,
'return_sequences': return_sequences,
'return_state': True,
'stateful': True,
'filters': filters,
'kernel_size': (num_row, num_col),
'padding': 'valid'}
layer = convolutional_recurrent.ConvLSTM2D(**kwargs)
layer.build(inputs.shape)
outputs = layer(input)
output, states = outputs[0], outputs[1:]
assert len(states) == 2
model = Model(input, states[0])
state = model.predict(inputs)
np.testing.assert_allclose(K.eval(layer.states[0]), state, atol=1e-4)

# test for output shape:
output = layer_test(convolutional_recurrent.ConvLSTM2D,
kwargs={'data_format': data_format,
Expand Down

0 comments on commit cafa286

Please sign in to comment.