Added dropout support in ConvLSTMCell. (#20089)
Implemented dropout and recurrent dropout in `ConvLSTMCell` with the same approach as in Keras 2.

Fixes #20063
hertschuh authored Aug 6, 2024
1 parent 445da52 commit 473a8ec
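For context, a minimal usage sketch (not part of the commit) showing the arguments this change makes effective; the filter count, rates, and input shape below are illustrative:

import numpy as np
import keras

# `dropout` masks the layer inputs and `recurrent_dropout` masks the
# hidden state; per #20063, ConvLSTM layers previously accepted these
# arguments but did not apply them.
layer = keras.layers.ConvLSTM2D(
    filters=8,
    kernel_size=3,
    dropout=0.2,
    recurrent_dropout=0.1,
)
x = np.random.rand(2, 4, 16, 16, 3).astype("float32")  # (batch, time, h, w, ch)
y_train = layer(x, training=True)   # dropout masks sampled and applied
y_infer = layer(x, training=False)  # no dropout at inference
print(y_train.shape)  # (2, 14, 14, 8) with default "valid" padding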
Showing 2 changed files with 43 additions and 23 deletions.
41 changes: 24 additions & 17 deletions keras/src/layers/rnn/conv_lstm.py
@@ -149,6 +149,7 @@ def __init__(

         self.dropout = min(1.0, max(0.0, dropout))
         self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))
+        self.dropout_mask_count = 4
         self.input_spec = InputSpec(ndim=rank + 2)
         self.state_size = -1  # Custom, defined in methods

@@ -233,23 +234,29 @@ def call(self, inputs, states, training=False):
         h_tm1 = states[0]  # previous memory state
         c_tm1 = states[1]  # previous carry state

-        # dp_mask = self.get_dropout_mask(inputs)
-        # rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
-
-        # if training and 0.0 < self.dropout < 1.0:
-        #     inputs *= dp_mask
-        # if training and 0.0 < self.recurrent_dropout < 1.0:
-        #     h_tm1 *= rec_dp_mask
-
-        inputs_i = inputs
-        inputs_f = inputs
-        inputs_c = inputs
-        inputs_o = inputs
-
-        h_tm1_i = h_tm1
-        h_tm1_f = h_tm1
-        h_tm1_c = h_tm1
-        h_tm1_o = h_tm1
+        if training and 0.0 < self.dropout < 1.0:
+            dp_mask = self.get_dropout_mask(inputs)
+            inputs_i = inputs * dp_mask[0]
+            inputs_f = inputs * dp_mask[1]
+            inputs_c = inputs * dp_mask[2]
+            inputs_o = inputs * dp_mask[3]
+        else:
+            inputs_i = inputs
+            inputs_f = inputs
+            inputs_c = inputs
+            inputs_o = inputs
+
+        if training and 0.0 < self.recurrent_dropout < 1.0:
+            rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)
+            h_tm1_i = h_tm1 * rec_dp_mask[0]
+            h_tm1_f = h_tm1 * rec_dp_mask[1]
+            h_tm1_c = h_tm1 * rec_dp_mask[2]
+            h_tm1_o = h_tm1 * rec_dp_mask[3]
+        else:
+            h_tm1_i = h_tm1
+            h_tm1_f = h_tm1
+            h_tm1_c = h_tm1
+            h_tm1_o = h_tm1

         (kernel_i, kernel_f, kernel_c, kernel_o) = ops.split(
             self.kernel, 4, axis=self.rank + 1
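The key behavioral point in the hunk above: the dropout mask is now a list of four independent masks, one per LSTM gate (input, forget, cell, output), matching Keras 2. A standalone NumPy sketch of that per-gate masking; `per_gate_dropout` is a hypothetical helper written for illustration, not Keras API:

import numpy as np

def per_gate_dropout(inputs, rate, num_gates=4, rng=None):
    # One independent inverted-dropout mask per gate, all applied to the
    # same input tensor; masks are scaled by 1/keep at training time.
    rng = rng or np.random.default_rng(0)
    keep = 1.0 - rate
    return [
        inputs * rng.binomial(1, keep, size=inputs.shape) / keep
        for _ in range(num_gates)
    ]

x = np.ones((2, 8, 8, 3), dtype="float32")
x_i, x_f, x_c, x_o = per_gate_dropout(x, rate=0.2)
print(np.allclose(x_i, x_f))  # almost surely False: each gate gets its own mask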
25 changes: 19 additions & 6 deletions keras/src/layers/rnn/dropout_rnn_cell.py
@@ -19,23 +19,36 @@ class DropoutRNNCell:
     all incoming steps, so that the same mask is used for every step.
     """

+    def _create_dropout_mask(self, step_input, dropout_rate):
+        count = getattr(self, "dropout_mask_count", None)
+        ones = ops.ones_like(step_input)
+        if count is None:
+            return backend.random.dropout(
+                ones, rate=dropout_rate, seed=self.seed_generator
+            )
+        else:
+            return [
+                backend.random.dropout(
+                    ones, rate=dropout_rate, seed=self.seed_generator
+                )
+                for _ in range(count)
+            ]
+
     def get_dropout_mask(self, step_input):
         if not hasattr(self, "_dropout_mask"):
             self._dropout_mask = None
         if self._dropout_mask is None and self.dropout > 0:
-            ones = ops.ones_like(step_input)
-            self._dropout_mask = backend.random.dropout(
-                ones, rate=self.dropout, seed=self.seed_generator
+            self._dropout_mask = self._create_dropout_mask(
+                step_input, self.dropout
             )
         return self._dropout_mask

     def get_recurrent_dropout_mask(self, step_input):
         if not hasattr(self, "_recurrent_dropout_mask"):
             self._recurrent_dropout_mask = None
         if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0:
-            ones = ops.ones_like(step_input)
-            self._recurrent_dropout_mask = backend.random.dropout(
-                ones, rate=self.recurrent_dropout, seed=self.seed_generator
+            self._recurrent_dropout_mask = self._create_dropout_mask(
+                step_input, self.recurrent_dropout
             )
         return self._recurrent_dropout_mask

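The `DropoutRNNCell` change keys off the optional `dropout_mask_count` attribute that `ConvLSTMCell` now sets to 4 (first hunk above): cells without the attribute still get a single cached mask, while cells that set it get a cached list of masks, and the cache means every timestep reuses the same mask(s). A self-contained mock of that logic, using NumPy in place of the Keras backend purely for illustration:

import numpy as np

class MockDropoutRNNCell:
    # Mimics DropoutRNNCell's caching and `dropout_mask_count` handling.
    def __init__(self, dropout, dropout_mask_count=None):
        self.dropout = dropout
        if dropout_mask_count is not None:
            self.dropout_mask_count = dropout_mask_count
        self._rng = np.random.default_rng(0)

    def _make_mask(self, shape, rate):
        keep = 1.0 - rate
        return self._rng.binomial(1, keep, shape) / keep

    def _create_dropout_mask(self, step_input, rate):
        count = getattr(self, "dropout_mask_count", None)
        if count is None:
            return self._make_mask(step_input.shape, rate)
        return [self._make_mask(step_input.shape, rate) for _ in range(count)]

    def get_dropout_mask(self, step_input):
        if not hasattr(self, "_dropout_mask"):
            self._dropout_mask = None
        if self._dropout_mask is None and self.dropout > 0:
            # Cached, so every timestep sees the same mask(s).
            self._dropout_mask = self._create_dropout_mask(step_input, self.dropout)
        return self._dropout_mask

x = np.ones((2, 4), dtype="float32")
simple = MockDropoutRNNCell(dropout=0.5)                      # SimpleRNNCell-style
conv = MockDropoutRNNCell(dropout=0.5, dropout_mask_count=4)  # ConvLSTMCell-style
print(isinstance(simple.get_dropout_mask(x), np.ndarray))  # True: one mask
print(len(conv.get_dropout_mask(x)))                       # 4: one mask per gate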
