
Commit a818733
Swap order of LMU states
Having the memory state come first is more intuitive, as it is
both always present and comes first in the computational flow.
drasmuss authored and tbekolay committed May 5, 2023
1 parent 41fdc58 commit a818733
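
To see what this reordering means in practice, here is a minimal sketch (not part of this commit) of user code that retrieves the states of an RNN built on LMUCell. The layer sizes are arbitrary, and it assumes TensorFlow 2.x with a keras-lmu version that includes this change:

import tensorflow as tf
import keras_lmu

cell = keras_lmu.LMUCell(
    memory_d=1,
    order=4,
    theta=10,
    hidden_cell=tf.keras.layers.SimpleRNNCell(8),
)
rnn = tf.keras.layers.RNN(cell, return_state=True)
output, *states = rnn(tf.ones((2, 5, 3)))

memory = states[0]   # LMU memory now comes first; shape (2, memory_d * order)
hidden = states[1:]  # hidden-cell states follow (previously memory came last)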
Showing 4 changed files with 13 additions and 11 deletions.
.nengobones.yml (4 changes: 2 additions & 2 deletions)

@@ -108,6 +108,6 @@ pyproject_toml: {}
 version_py:
   type: semver
   major: 0
-  minor: 5
-  patch: 1
+  minor: 6
+  patch: 0
   release: false
CHANGES.rst (4 changes: 3 additions & 1 deletion)

@@ -19,7 +19,7 @@ Release history
 - Removed
 - Fixed
 
-0.5.1 (unreleased)
+0.6.0 (unreleased)
 ==================
 
 *Compatible with TensorFlow 2.4 - 2.11*

@@ -31,6 +31,8 @@ Release history
   are met, as before). (`#52`_)
 - Allow ``input_to_hidden=True`` with ``hidden_cell=None``. This will act as a skip
   connection. (`#52`_)
+- Changed order of LMU states so that the LMU memory state always comes first, and
+  any states from the hidden cell come afterwards. (`#52`_)
 
 **Fixed**
 
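As a hedged illustration of what the changelog entry means for explicitly supplied initial states: with this change the memory state is passed first, followed by any hidden-cell states. The names and shapes below are illustrative, not taken from the diff:

import tensorflow as tf
import keras_lmu

cell = keras_lmu.LMUCell(
    memory_d=2,
    order=8,
    theta=16,
    hidden_cell=tf.keras.layers.SimpleRNNCell(16),
)
rnn = tf.keras.layers.RNN(cell)

batch = 4
m0 = tf.zeros((batch, cell.memory_d * cell.order))  # memory state first
h0 = tf.zeros((batch, cell.hidden_cell.units))      # hidden state(s) after
out = rnn(tf.ones((batch, 10, 3)), initial_state=[m0, h0])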
keras_lmu/layers.py (14 changes: 7 additions & 7 deletions)

@@ -171,9 +171,9 @@ def __init__(
             self.hidden_output_size = self.hidden_cell.units
             self.hidden_state_size = [self.hidden_cell.units]
 
-        self.state_size = tf.nest.flatten(self.hidden_state_size) + [
-            self.memory_d * self.order
-        ]
+        self.state_size = [self.memory_d * self.order] + tf.nest.flatten(
+            self.hidden_state_size
+        )
         self.output_size = self.hidden_output_size
 
     @property

@@ -329,10 +329,10 @@ def call(self, inputs, states, training=None):  # noqa: C901
 
         states = tf.nest.flatten(states)
 
-        # state for the hidden cell
-        h = states[:-1]
-        # state for the LMU memory
-        m = states[-1]
+        # state for the LMU memory
+        m = states[0]
+        # state for the hidden cell
+        h = states[1:]
 
         # compute memory input
         u = (

@@ -403,7 +403,7 @@ def call(self, inputs, states, training=None):  # noqa: C901
             o = self.hidden_cell(h_in, training=training)
             h = [o]
 
-        return o, h + [m]
+        return o, [m] + h
 
     def reset_dropout_mask(self):
         """Reset dropout mask for memory and hidden components."""
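To make the layers.py hunks concrete, here is a standalone toy sketch of the new bookkeeping (assumed values, not the library code): state_size lists the flattened memory size first, and call unpacks states in the same order:

import tensorflow as tf

memory_d, order = 2, 4
hidden_state_size = [16]  # e.g. a SimpleRNNCell(16); empty if hidden_cell=None

# memory state first, then any hidden-cell states
state_size = [memory_d * order] + tf.nest.flatten(hidden_state_size)
print(state_size)  # [8, 16]

def unpack(states):
    m = states[0]   # LMU memory, always present
    h = states[1:]  # hidden-cell states, possibly empty
    return m, h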
keras_lmu/version.py (2 changes: 1 addition & 1 deletion)

@@ -11,7 +11,7 @@
 tagged with the version.
 """
 
-version_info = (0, 5, 1)
+version_info = (0, 6, 0)
 
 name = "keras-lmu"
 dev = 0
