Support unknown sequence lengths in LMUFeedforward #52

Merged 5 commits on May 5, 2023
10 changes: 5 additions & 5 deletions .nengobones.yml
@@ -81,23 +81,23 @@ ci_scripts:
TF_FORCE_GPU_ALLOW_GROWTH: "true"
TF_VERSION: $TF_VERSION
remote_setup:
- micromamba install -y "$TF_VERSION"
- micromamba install -y "$TF_VERSION" cudnn=8.4
- template: remote-script
remote_script: docs
output_name: remote-docs
host: azure-docs
azure_name: nengo-dl-docs
azure_group: nengo-ci
remote_setup:
- micromamba install -y "$TF_VERSION"
- micromamba install -y "$TF_VERSION" cudnn=8.4
- template: remote-script
remote_script: examples
output_name: remote-examples
host: azure-examples
azure_name: nengo-dl-examples
azure_group: nengo-ci
remote_setup:
- micromamba install -y "$TF_VERSION"
- micromamba install -y "$TF_VERSION" cudnn=8.4
- template: deploy
wheel: true

@@ -108,6 +108,6 @@ pyproject_toml: {}
version_py:
type: semver
major: 0
minor: 5
patch: 1
minor: 6
patch: 0
release: false
18 changes: 17 additions & 1 deletion CHANGES.rst
@@ -19,11 +19,27 @@ Release history
- Removed
- Fixed

0.5.1 (unreleased)
0.6.0 (unreleased)
==================

*Compatible with TensorFlow 2.4 - 2.11*

**Changed**

- ``LMUFeedforward`` can now be used with unknown sequence lengths, and ``LMU`` will
use ``LMUFeedforward`` for unknown sequence lengths (as long as the other conditions
are met, as before). (`#52`_)
- Allow ``input_to_hidden=True`` with ``hidden_cell=None``. This will act as a skip
connection. (`#52`_)
- Changed order of LMU states so that the LMU memory state always comes first, and
any states from the hidden cell come afterwards. (`#52`_)

**Fixed**

- Fixed errors when setting non-default dtype on LMU layers. (`#52`_)

.. _#52: https://github.com/nengo/keras-lmu/pull/52

0.5.0 (January 26, 2023)
========================

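As a brief illustration of the changelog entries above (not part of the PR diff itself): an ``LMUFeedforward`` layer can now be built on inputs whose temporal axis is left as ``None``; in that case the layer approximates the impulse response over ``5 * theta`` steps and computes the impulse-response FFT per call. The sketch below assumes the public ``keras_lmu`` constructor signature (``memory_d``, ``order``, ``theta``, ``hidden_cell``) from released versions; the specific sizes and hidden cell are hypothetical values chosen only for the example.

import tensorflow as tf
import keras_lmu

# Temporal axis left unspecified; before this change, build() raised a ValueError.
inputs = tf.keras.Input(shape=(None, 4))
lmu = keras_lmu.LMUFeedforward(
    memory_d=1,
    order=8,
    theta=16,
    hidden_cell=tf.keras.layers.SimpleRNNCell(16),
    return_sequences=True,
)
# Building emits the "Approximating unknown impulse length" warning added in this PR.
outputs = lmu(inputs)
model = tf.keras.Model(inputs, outputs)
model.predict(tf.ones((2, 37, 4)))  # the sequence length is chosen at call time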
102 changes: 67 additions & 35 deletions keras_lmu/layers.py
@@ -1,5 +1,7 @@
"""Core classes for the KerasLMU package."""

import warnings

import numpy as np
import tensorflow as tf
from packaging import version
@@ -154,9 +156,10 @@ def __init__(
)

if self.hidden_cell is None:
for conn in ("hidden_to_memory", "input_to_hidden"):
if getattr(self, conn):
raise ValueError(f"{conn} must be False if hidden_cell is None")
if self.hidden_to_memory:
raise ValueError(
"hidden_to_memory must be False if hidden_cell is None"
)

self.hidden_output_size = self.memory_d * self.order
self.hidden_state_size = []
@@ -168,9 +171,9 @@ def __init__(
self.hidden_output_size = self.hidden_cell.units
self.hidden_state_size = [self.hidden_cell.units]

self.state_size = tf.nest.flatten(self.hidden_state_size) + [
self.memory_d * self.order
]
self.state_size = [self.memory_d * self.order] + tf.nest.flatten(
self.hidden_state_size
)
self.output_size = self.hidden_output_size

@property
@@ -228,7 +231,7 @@ def _cont2discrete_zoh(A, B):
"""

# combine A/B and pad to make square matrix
em_upper = tf.concat([A, B], axis=0)
em_upper = tf.concat([A, B], axis=0) # pylint: disable=no-value-for-parameter
em = tf.pad(em_upper, [(0, 0), (0, B.shape[0])])

# compute matrix exponential
@@ -326,13 +329,17 @@ def call(self, inputs, states, training=None):  # noqa: C901

states = tf.nest.flatten(states)

# state for the hidden cell
h = states[:-1]
# state for the LMU memory
m = states[-1]
m = states[0]
# state for the hidden cell
h = states[1:]

# compute memory input
u = tf.concat((inputs, h[0]), axis=1) if self.hidden_to_memory else inputs
u = (
tf.concat((inputs, h[0]), axis=1) # pylint: disable=no-value-for-parameter
if self.hidden_to_memory
else inputs
)
if self.dropout > 0:
u *= self.get_dropout_mask_for_cell(u, training)
if self.kernel is not None:
@@ -381,7 +388,11 @@ def call(self, inputs, states, training=None):  # noqa: C901
m = tf.reshape(m, (-1, self.memory_d * self.order))

# apply hidden cell
h_in = tf.concat((m, inputs), axis=1) if self.input_to_hidden else m
h_in = (
tf.concat((m, inputs), axis=1) # pylint: disable=no-value-for-parameter
if self.input_to_hidden
else m
)

if self.hidden_cell is None:
o = h_in
@@ -392,7 +403,7 @@ def call(self, inputs, states, training=None):  # noqa: C901
o = self.hidden_cell(h_in, training=training)
h = [o]

return o, h + [m]
return o, [m] + h

def reset_dropout_mask(self):
"""Reset dropout mask for memory and hidden components."""
@@ -593,7 +604,7 @@ def theta(self):

return self._init_theta

def build(self, input_shapes):
def build(self, input_shape):
"""
Builds the layer.

@@ -604,12 +615,11 @@ def build(self, input_shapes):
with some additional bookkeeping.
"""

super().build(input_shapes)
super().build(input_shape)

if (
not self.hidden_to_memory
and not self.memory_to_memory
and input_shapes[1] is not None
and not self.trainable_theta
):
self.layer = LMUFeedforward(
@@ -626,6 +636,7 @@ def build(self, input_shapes):
bias_regularizer=self.bias_regularizer,
dropout=self.dropout,
return_sequences=self.return_sequences,
dtype=self.dtype,
)
else:
self.layer = tf.keras.layers.RNN(
@@ -648,11 +659,13 @@ def build(self, input_shapes):
bias_regularizer=self.bias_regularizer,
dropout=self.dropout,
recurrent_dropout=self.recurrent_dropout,
dtype=self.dtype,
),
return_sequences=self.return_sequences,
dtype=self.dtype,
)

self.layer.build(input_shapes)
self.layer.build(input_shape)

def call(self, inputs, training=None):
"""
@@ -790,9 +803,6 @@ def __init__(
):
super().__init__(**kwargs)

if input_to_hidden and hidden_cell is None:
raise ValueError("input_to_hidden must be False if hidden_cell is None")

if conv_mode not in ("fft", "raw"):
raise ValueError(f"Unrecognized conv mode '{conv_mode}'")

@@ -826,8 +836,10 @@ def __init__(
discretizer=discretizer,
kernel_initializer=None,
trainable=False,
dtype=self.dtype,
),
return_sequences=True,
dtype=self.dtype,
)
self.impulse_response = None
self.kernel = None
@@ -846,31 +858,37 @@ def build(self, input_shape):

super().build(input_shape)

seq_len = input_shape[1]
enc_d = input_shape[-1]

seq_len = input_shape[1]
if seq_len is None:
# TODO: we could dynamically run the impulse response for longer if
# needed using stateful=True
raise ValueError(
f"LMUFeedforward requires that the input shape's temporal axis be "
f"fully specified (got {seq_len})"
theta_factor = 5
warnings.warn(
f"Approximating unknown impulse length with {theta_factor}*theta; "
f"setting a fixed sequence length on inputs will remove the need for "
f"approximation"
)
impulse_len = self.theta * theta_factor
else:
impulse_len = seq_len

impulse = tf.reshape(tf.eye(seq_len, 1), (1, -1, 1))
impulse = tf.reshape(tf.eye(impulse_len, 1), (1, -1, 1))

self.impulse_response = tf.squeeze(
self.delay_layer(impulse, training=False), axis=0
)

if self.conv_mode == "fft":
self.impulse_response = tf.signal.rfft(
tf.transpose(self.impulse_response),
fft_length=[2 * seq_len],
self.impulse_response_fft = (
None
if seq_len is None
else tf.signal.rfft(
tf.transpose(self.impulse_response),
fft_length=[2 * seq_len],
)
)
else:
if self.truncate_ir is not None:
assert self.impulse_response.shape == (seq_len, self.order)
assert self.impulse_response.shape == (impulse_len, self.order)

cumsum = tf.math.cumsum(
tf.math.abs(self.impulse_response), axis=0, reverse=True
@@ -949,13 +967,19 @@ def call(self, inputs, training=None):
m = self._raw_convolution(u)

# apply hidden cell
h_in = tf.concat((m, inputs), axis=-1) if self.input_to_hidden else m
h_in = (
tf.concat((m, inputs), axis=-1) # pylint: disable=no-value-for-parameter
if self.input_to_hidden
else m
)

if self.hidden_cell is None:
h = h_in if self.return_sequences else h_in[:, -1]
elif hasattr(self.hidden_cell, "state_size"):
h = tf.keras.layers.RNN(
self.hidden_cell, return_sequences=self.return_sequences
self.hidden_cell,
return_sequences=self.return_sequences,
dtype=self.dtype,
)(h_in, training=training)
else:
if not self.return_sequences:
@@ -977,9 +1001,17 @@ def _fft_convolution(self, u):
# Pad sequences to avoid circular convolution
# Perform the FFT
fft_input = tf.signal.rfft(u, fft_length=[2 * seq_len])
impulse_response = (
tf.signal.rfft(
tf.transpose(self.impulse_response[:seq_len]),
fft_length=[2 * seq_len],
)
if self.impulse_response_fft is None
else self.impulse_response_fft
)

# Elementwise product of FFT (with broadcasting)
result = tf.expand_dims(fft_input, axis=-2) * self.impulse_response
result = tf.expand_dims(fft_input, axis=-2) * impulse_response

# Inverse FFT
m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]
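For context on the ``_fft_convolution`` changes above: zero-padding both the input and the impulse response to ``2 * seq_len`` before the FFT turns the FFT's circular convolution into an ordinary causal (linear) convolution, which is why the result is truncated back to ``seq_len`` steps afterwards. A standalone sketch of that identity (illustrative only, not code from this PR):

import numpy as np
import tensorflow as tf

seq_len = 8
u = tf.random.normal((seq_len,))  # input sequence
h = tf.random.normal((seq_len,))  # impulse response

# Pad to 2 * seq_len so the product of FFTs gives a linear, not circular, convolution
fft_u = tf.signal.rfft(u, fft_length=[2 * seq_len])
fft_h = tf.signal.rfft(h, fft_length=[2 * seq_len])
m = tf.signal.irfft(fft_u * fft_h, fft_length=[2 * seq_len])[:seq_len]

# Matches a direct (non-circular) convolution truncated to seq_len steps
expected = np.convolve(u.numpy(), h.numpy())[:seq_len]
np.testing.assert_allclose(m.numpy(), expected, rtol=1e-4, atol=1e-4)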