From 7e1b230f50dee06de9cd008d506201c367faa8e2 Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Thu, 17 Jun 2021 19:03:54 -0300 Subject: [PATCH] squash! Add trainable theta and discretization options - Simplify A/B calculations - More efficient cont2discrete implementation - Rename train_theta to trainable_theta - Store A/B as constants - Remove A/B caching in zoh training - Add test to make sure that zoh and euler produce (approximately) the same output - Disable FFT if trainable_theta=True - Update trainable_theta tests --- CHANGES.rst | 8 +- keras_lmu/layers.py | 238 ++++++++++++++------------------- keras_lmu/tests/test_layers.py | 231 ++++++++++++++------------------ keras_lmu/version.py | 2 +- 4 files changed, 214 insertions(+), 265 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index d6ba7c89..8c9ab6e3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -19,7 +19,7 @@ Release history - Removed - Fixed -0.3.2 (unreleased) +0.4.0 (unreleased) ================== **Added** @@ -30,6 +30,12 @@ Release history are satisfied (no hidden-to-memory or memory-to-memory connections, and the sequence length is not ``None``). (`#40`_) +**Changed** + +- The ``A`` and ``B`` matrices are now stored as constants instead of non-trainable + variables. This can improve the training/inference speed, but it means that saved + weights from previous versions will be incompatible. (`#40`_) + .. _#40: https://github.com/nengo/keras-lmu/pull/40 0.3.1 (November 16, 2020) diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index a92b223e..3248889e 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -4,6 +4,7 @@ import numpy as np import tensorflow as tf +from scipy.signal import cont2discrete from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin @@ -34,11 +35,11 @@ class to create a recurrent Keras layer to process the whole sequence. Calling step. If this value is smaller than the size of the input sequence, only that number of steps will be represented at the time of prediction, however the entire sequence will still be processed in order for information to be - projected to and from the hidden layer. If ``train_theta`` is enabled, then + projected to and from the hidden layer. If ``trainable_theta`` is enabled, then theta will be updated during the course of training. hidden_cell : ``tf.keras.layers.Layer`` Keras Layer/RNNCell implementing the hidden component. - train_theta : bool + trainable_theta : bool If True, theta is learnt over the course of training. Otherwise, it is kept constant. hidden_to_memory : bool @@ -51,13 +52,11 @@ class to create a recurrent Keras layer to process the whole sequence. Calling If True, connect the input directly to the hidden component (in addition to the connection from the memory component) (default False). discretizer : str - The method to use to discretize the A and B matrices of the LMU. Current - options are "zoh" (short for Zero Order Hold) and "euler". The - training with "zoh" might slower when compared to "euler". Preference of one - method over the other may vary according to the use case. Note that a - larger theta is needed when discretizing using "euler". At least a value - that is larger than "4*order" is recommended. Discretizing with "zoh" is - recommended if training time is not a constraint. + The method used to discretize the A and B matrices of the LMU. Current + options are "zoh" (short for Zero Order Hold) and "euler". + "zoh" is more accurate, but training will be slower than "euler" if + ``trainable_theta=True``. Note that a larger theta is needed when discretizing + using "euler" (a value that is larger than ``4*order`` is recommended). kernel_initializer : ``tf.initializers.Initializer`` Initializer for weights from input to memory/hidden component. If ``None``, no weights will be used, and the input size must match the memory/hidden size. @@ -84,7 +83,7 @@ def __init__( order, theta, hidden_cell, - train_theta=False, + trainable_theta=False, hidden_to_memory=False, memory_to_memory=False, input_to_hidden=False, @@ -101,7 +100,7 @@ def __init__( self.order = order self._init_theta = theta self.hidden_cell = hidden_cell - self.train_theta = train_theta + self.trainable_theta = trainable_theta self.hidden_to_memory = hidden_to_memory self.memory_to_memory = memory_to_memory self.input_to_hidden = input_to_hidden @@ -119,7 +118,7 @@ def __init__( if self.discretizer not in ("zoh", "euler"): raise ValueError( - f"discretizer must be 'zoh' or 'euler'. Got {self.discretizer}" + f"discretizer must be 'zoh' or 'euler' (got '{self.discretizer}')" ) if self.hidden_cell is None: @@ -142,69 +141,67 @@ def __init__( ] self.output_size = self.hidden_output_size - def _gen_constants(self): - """ - Generates constants for discretizing A and B matrices. - """ - self.CONST_Q = np.arange(self.order, dtype=np.float32) - self.CONST_R = (2 * self.CONST_Q + 1)[:, None] - self.CONST_j, self.CONST_i = np.meshgrid(self.CONST_Q, self.CONST_Q) - self.CONST_A = ( - np.where( - self.CONST_i < self.CONST_j, - -1, - (-1.0) ** (self.CONST_i - self.CONST_j + 1), + def _gen_AB(self): + """Generates A and B matrices.""" + + # compute analog A/B matrices + Q = np.arange(self.order, dtype=np.float64) + R = (2 * Q + 1)[:, None] + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R + B = (-1.0) ** Q[:, None] * R + + # discretize matrices + if self.discretizer == "zoh": + if self.trainable_theta: + # save the un-discretized matrices for use in .call + self._base_A = tf.constant(A.T, dtype=tf.float32) + self._base_B = tf.constant(B.T, dtype=tf.float32) + + A, B, *_ = cont2discrete( + ( + A / self._init_theta, + B / self._init_theta, + np.ones((1, self.order)), + np.zeros((1,)), + ), + dt=1.0, + method="zoh", ) - * self.CONST_R - ) - self.CONST_B = (-1.0) ** self.CONST_Q[:, None] * self.CONST_R + else: + if not self.trainable_theta: + A = A / self._init_theta + np.eye(self.order) + B = B / self._init_theta - def _gen_AB(self): - """ - Generates discretized A and B matrices, which are then set as weights A and B - respectively + self.A = tf.constant(A.T, dtype=self.dtype) + self.B = tf.constant(B.T, dtype=self.dtype) + + @staticmethod + def _cont2discrete_zoh(A, B): """ + Function to discretize A and B matrices using Zero Order Hold method. - def _cont2discrete_zoh(A, B, dt=1.0): - """ - Function to discretize A and B matrices using Zero Order Hold method. - Functionally equivalent to scipy.signal.cont2discrete(method="zoh"). - """ - em_upper = tf.concat([A, B], axis=1) - em_lower = tf.concat( - [ - tf.zeros(shape=(B.shape[1], A.shape[0])), - tf.zeros(shape=(B.shape[1], B.shape[1])), - ], - axis=1, - ) - em = tf.concat([em_upper, em_lower], axis=0) + Functionally equivalent to + ``scipy.signal.cont2discrete((A.T, B.T, _, _), method="zoh", dt=1.0)`` + (but implemented in TensorFlow so that it is differentiable). - ms = tf.linalg.expm(dt * em) - ms = ms[: A.shape[0], :] + Note that this accepts and returns matrices that are transposed from the + standard linear system implementation (as that makes it easier to use in + `.call`). + """ - discrt_A = ms[:, 0 : A.shape[1]] - discrt_B = ms[:, A.shape[1] :] + # combine A/B and pad to make square matrix + em_upper = tf.concat([A, B], axis=0) + em = tf.pad(em_upper, [(0, 0), (0, B.shape[0])]) - return tf.transpose(discrt_A), tf.transpose(discrt_B) + # compute matrix exponential + ms = tf.linalg.expm(em) - # A and B cannot be decoupled from theta when using zoh. thus generated - # using theta here and set as weights in ``build`` and ``call``. - if self.discretizer == "zoh": - self._A, self._B = _cont2discrete_zoh( - self.CONST_A * self.theta_inv, - self.CONST_B * self.theta_inv, - dt=1.0, - ) - # A and B can be decoupled from theta when using euler. hence, only - # generated once in ``build``, regardless if theta is trainable or not. - # if theta is trainable, A and B weights are set without theta and - # division by theta, according to euler discretization formula, is done - # in ``call``. - elif self.discretizer == "euler": - self._A, self._B = self.CONST_A.T, self.CONST_B.T - if not self.train_theta: - self._A, self._B = self._A * self.theta_inv, self._B * self.theta_inv + # slice A/B back out of combined matrix + discrt_A = ms[: A.shape[0], : A.shape[1]] + discrt_B = ms[A.shape[0] :, : A.shape[1] :] + + return discrt_A, discrt_B def build(self, input_shape): """ @@ -239,16 +236,15 @@ def build(self, input_shape): # when using euler, 1/theta results in better gradients for the memory # update since you are multiplying 1/theta, as compared to dividing theta - if self.train_theta: + if self.trainable_theta: self.theta_inv = self.add_weight( - name="theta_kernel", - shape=(1), + name="theta_inv", + shape=(), initializer=tf.initializers.constant(1 / self._init_theta), - trainable=True, constraint=tf.keras.constraints.NonNeg(), ) else: - self.theta_inv = 1 / self._init_theta + self.theta_inv = tf.constant(1 / self._init_theta, dtype=self.dtype) if self.memory_to_memory: self.recurrent_kernel = self.add_weight( @@ -266,27 +262,8 @@ def build(self, input_shape): with tf.name_scope(self.hidden_cell.name): self.hidden_cell.build((input_shape[0], hidden_input_d)) - # set initial A and B weights - self._gen_constants() # generate constants - self._gen_AB() # generate initital values of A and B - - self.A = self.add_weight( - name="A", - shape=(self.order, self.order), - initializer=tf.initializers.constant( - self._A.numpy() if self.discretizer == "zoh" else self._A - ), - trainable=False, - ) - - self.B = self.add_weight( - name="B", - shape=(1, self.order), # system is SISO - initializer=tf.initializers.constant( - self._B.numpy() if self.discretizer == "zoh" else self._B - ), - trainable=False, - ) + # generate A and B matrices + self._gen_AB() def call(self, inputs, states, training=None): """ @@ -330,21 +307,23 @@ def call(self, inputs, states, training=None): u = tf.expand_dims(u, -1) # update memory - if self.discretizer == "zoh" and training and self.train_theta: - # ``theta`` cannot be decoupled. generate new A and B matrices - # and assign them as weights. - self._gen_AB() - self.A.assign(self._A) - self.B.assign(self._B) - A, B = self._A, self._B + if self.discretizer == "zoh" and self.trainable_theta: + # apply updated theta and re-discretize + A, B = LMUCell._cont2discrete_zoh( + self._base_A * self.theta_inv, self._base_B * self.theta_inv + ) else: A, B = self.A, self.B + _m = tf.matmul(m, A) + tf.matmul(u, B) - if self.discretizer == "euler": - # if training theta, ``theta_inv`` can be decopled when using euler - # by dividing original A and B weights if training theta - m = m + ((self.theta_inv if self.train_theta else 1) * _m) + if self.discretizer == "euler" and self.trainable_theta: + # apply updated theta. this is the same as scaling A/B by theta, but it's + # more efficient to do it this way. + # note that when computing this way the A matrix does not + # include the identity matrix along the diagonal (since we don't want to + # scale that part by theta), which is why we do += instead of = + m += _m * self.theta_inv else: m = _m @@ -387,7 +366,7 @@ def get_config(self): order=self.order, theta=self._init_theta, hidden_cell=tf.keras.layers.serialize(self.hidden_cell), - train_theta=self.train_theta, + trainable_theta=self.trainable_theta, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, @@ -439,11 +418,11 @@ class LMU(tf.keras.layers.Layer): step. If this value is smaller than the size of the input sequence, only that number of steps will be represented at the time of prediction, however the entire sequence will still be processed in order for information to be - projected to and from the hidden layer. If ``train_theta`` is enabled, then + projected to and from the hidden layer. If ``trainable_theta`` is enabled, then theta will be updated during the course of training. hidden_cell : ``tf.keras.layers.Layer`` Keras Layer/RNNCell implementing the hidden component. - train_theta : bool + trainable_theta : bool If True, theta is learnt over the course of training. Otherwise, it is kept constant. hidden_to_memory : bool @@ -456,13 +435,11 @@ class LMU(tf.keras.layers.Layer): If True, connect the input directly to the hidden component (in addition to the connection from the memory component) (default False). discretizer : str - The method to use to discretize the A and B matrices of the LMU. Current - options are "zoh" (short for Zero Order Hold) and "euler". The - training with "zoh" might slower when compared to "euler". Preference of one - method over the other may vary according to the use case. Note that a - larger theta is needed when discretizing using "euler". At least a value - that is larger than "4*order" is recommended. Discretizing with "zoh" is - recommended if training time is not a constraint. + The method used to discretize the A and B matrices of the LMU. Current + options are "zoh" (short for Zero Order Hold) and "euler". + "zoh" is more accurate, but training will be slower than "euler" if + ``trainable_theta=True``. Note that a larger theta is needed when discretizing + using "euler" (a value that is larger than ``4*order`` is recommended). kernel_initializer : ``tf.initializers.Initializer`` Initializer for weights from input to memory/hidden component. If ``None``, no weights will be used, and the input size must match the memory/hidden size. @@ -492,7 +469,7 @@ def __init__( order, theta, hidden_cell, - train_theta=False, + trainable_theta=False, hidden_to_memory=False, memory_to_memory=False, input_to_hidden=False, @@ -511,7 +488,7 @@ def __init__( self.order = order self._init_theta = theta self.hidden_cell = hidden_cell - self.train_theta = train_theta + self.trainable_theta = trainable_theta self.hidden_to_memory = hidden_to_memory self.memory_to_memory = memory_to_memory self.input_to_hidden = input_to_hidden @@ -540,13 +517,13 @@ def build(self, input_shapes): not self.hidden_to_memory and not self.memory_to_memory and input_shapes[1] is not None + and not self.trainable_theta ): self.layer = LMUFFT( memory_d=self.memory_d, order=self.order, theta=self._init_theta, hidden_cell=self.hidden_cell, - train_theta=self.train_theta, input_to_hidden=self.input_to_hidden, discretizer=self.discretizer, kernel_initializer=self.kernel_initializer, @@ -560,7 +537,7 @@ def build(self, input_shapes): order=self.order, theta=self._init_theta, hidden_cell=self.hidden_cell, - train_theta=self.train_theta, + trainable_theta=self.trainable_theta, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, @@ -598,7 +575,7 @@ def get_config(self): order=self.order, theta=self._init_theta, hidden_cell=tf.keras.layers.serialize(self.hidden_cell), - train_theta=self.train_theta, + trainable_theta=self.trainable_theta, hidden_to_memory=self.hidden_to_memory, memory_to_memory=self.memory_to_memory, input_to_hidden=self.input_to_hidden, @@ -646,24 +623,19 @@ class LMUFFT(tf.keras.layers.Layer): step. If this value is smaller than the size of the input sequence, only that number of steps will be represented at the time of prediction, however the entire sequence will still be processed in order for information to be - projected to and from the hidden layer. If ``train_theta`` is enabled, then + projected to and from the hidden layer. If ``trainable_theta`` is enabled, then theta will be updated during the course of training. hidden_cell : ``tf.keras.layers.Layer`` Keras Layer implementing the hidden component. - train_theta : bool - If True, theta is learnt over the course of training. Otherwise, it is kept - constant. input_to_hidden : bool If True, connect the input directly to the hidden component (in addition to the connection from the memory component) (default False). discretizer : str - The method to use to discretize the A and B matrices of the LMU. Current - options are "zoh" (short for Zero Order Hold) and "euler". The - training with "zoh" might slower when compared to "euler". Preference of one - method over the other may vary according to the use case. Note that a - larger theta is needed when discretizing using "euler". At least a value - that is larger than "4*order" is recommended. Discretizing with "zoh" is - recommended if training time is not a constraint. + The method used to discretize the A and B matrices of the LMU. Current + options are "zoh" (short for Zero Order Hold) and "euler". + "zoh" is more accurate, but training will be slower than "euler" if + ``trainable_theta=True``. Note that a larger theta is needed when discretizing + using "euler" (a value that is larger than ``4*order`` is recommended). kernel_initializer : ``tf.initializers.Initializer`` Initializer for weights from input to memory/hidden component. If ``None``, no weights will be used, and the input size must match the memory/hidden size. @@ -680,7 +652,6 @@ def __init__( order, theta, hidden_cell, - train_theta=False, input_to_hidden=False, discretizer="zoh", kernel_initializer="glorot_uniform", @@ -697,7 +668,6 @@ def __init__( self.order = order self._init_theta = theta self.hidden_cell = hidden_cell - self.train_theta = train_theta self.input_to_hidden = input_to_hidden self.discretizer = discretizer self.kernel_initializer = kernel_initializer @@ -711,13 +681,12 @@ def __init__( order=order, theta=theta, hidden_cell=None, - train_theta=train_theta, + trainable_theta=False, input_to_hidden=False, hidden_to_memory=False, memory_to_memory=False, discretizer=discretizer, kernel_initializer=None, - dropout=0, trainable=False, ), return_sequences=True, @@ -851,7 +820,6 @@ def get_config(self): order=self.order, theta=self._init_theta, hidden_cell=tf.keras.layers.serialize(self.hidden_cell), - train_theta=self.train_theta, input_to_hidden=self.input_to_hidden, discretizer=self.discretizer, kernel_initializer=self.kernel_initializer, diff --git a/keras_lmu/tests/test_layers.py b/keras_lmu/tests/test_layers.py index df3c36f8..a6aa35b7 100644 --- a/keras_lmu/tests/test_layers.py +++ b/keras_lmu/tests/test_layers.py @@ -62,8 +62,8 @@ def test_multivariate_lmu(rng, discretizer): for i in range(memory_d): assert np.allclose( - results[0][..., i * order : (i + 1) * order], results[i + 1], atol=1e-5 - ), max(abs(results[0][..., i * order : (i + 1) * order] - results[i + 1])) + results[0][..., i * order : (i + 1) * order], results[i + 1], atol=2e-5 + ), np.max(abs(results[0][..., i * order : (i + 1) * order] - results[i + 1])) @pytest.mark.parametrize("has_input_kernel", (True, False)) @@ -103,14 +103,15 @@ def test_layer_vs_cell(rng, has_input_kernel, fft, discretizer): ): assert np.allclose(w0.numpy(), w1.numpy()) - atol = 2e-6 if fft else 1e-8 - assert np.allclose(cell_out, lmu_cell(inp), atol=atol) - assert np.allclose(cell_out, layer_out, atol=atol) + assert np.allclose(cell_out, lmu_cell(inp)) + assert np.allclose(cell_out, layer_out, atol=3e-6 if fft else 1e-8), np.max( + np.abs(cell_out - layer_out) + ) @pytest.mark.parametrize("discretizer", ("zoh", "euler")) -@pytest.mark.parametrize("train_theta", (True, False)) -def test_save_load_weights(rng, tmp_path, discretizer, train_theta): +@pytest.mark.parametrize("trainable_theta", (True, False)) +def test_save_load_weights(rng, tmp_path, discretizer, trainable_theta): memory_d = 4 order = 12 n_steps = 10 @@ -125,7 +126,7 @@ def test_save_load_weights(rng, tmp_path, discretizer, train_theta): n_steps, tf.keras.layers.SimpleRNNCell(units=64), discretizer=discretizer, - train_theta=train_theta, + trainable_theta=trainable_theta, return_sequences=True, )(inp) model0 = tf.keras.Model(inp, lmu0) @@ -137,7 +138,7 @@ def test_save_load_weights(rng, tmp_path, discretizer, train_theta): n_steps, tf.keras.layers.SimpleRNNCell(units=64), discretizer=discretizer, - train_theta=train_theta, + trainable_theta=trainable_theta, return_sequences=True, )(inp) model1 = tf.keras.Model(inp, lmu1) @@ -153,9 +154,12 @@ def test_save_load_weights(rng, tmp_path, discretizer, train_theta): @pytest.mark.parametrize("discretizer", ("zoh", "euler")) -@pytest.mark.parametrize("train_theta", (True, False)) +@pytest.mark.parametrize("trainable_theta", (True, False)) @pytest.mark.parametrize("mode", ("cell", "lmu", "fft")) -def test_save_load_serialization(mode, tmp_path, train_theta, discretizer): +def test_save_load_serialization(mode, tmp_path, trainable_theta, discretizer): + if mode == "fft" and trainable_theta: + pytest.skip("FFT does not support trainable theta") + inp = tf.keras.Input((10 if mode == "fft" else None, 32)) if mode == "cell": out = tf.keras.layers.RNN( @@ -164,7 +168,7 @@ def test_save_load_serialization(mode, tmp_path, train_theta, discretizer): 2, 3, tf.keras.layers.SimpleRNNCell(4), - train_theta=train_theta, + trainable_theta=trainable_theta, discretizer=discretizer, ), return_sequences=True, @@ -177,7 +181,7 @@ def test_save_load_serialization(mode, tmp_path, train_theta, discretizer): tf.keras.layers.SimpleRNNCell(4), return_sequences=True, memory_to_memory=True, - train_theta=train_theta, + trainable_theta=trainable_theta, discretizer=discretizer, )(inp) elif mode == "fft": @@ -186,7 +190,6 @@ def test_save_load_serialization(mode, tmp_path, train_theta, discretizer): 2, 3, tf.keras.layers.SimpleRNNCell(4), - train_theta=train_theta, discretizer=discretizer, return_sequences=True, )(inp) @@ -220,15 +223,13 @@ def test_save_load_serialization(mode, tmp_path, train_theta, discretizer): ) @pytest.mark.parametrize("memory_d", [1, 4]) @pytest.mark.parametrize("discretizer", ("zoh", "euler")) -@pytest.mark.parametrize("train_theta", (True, False)) -def test_fft(return_sequences, hidden_cell, memory_d, discretizer, train_theta, rng): +def test_fft(return_sequences, hidden_cell, memory_d, discretizer, rng): kwargs = dict( memory_d=memory_d, order=2, theta=12, hidden_cell=hidden_cell(), discretizer=discretizer, - train_theta=train_theta, ) x = rng.uniform(-1, 1, size=(2, 10, 32)) @@ -263,29 +264,30 @@ def test_validation_errors(): @pytest.mark.parametrize( - "hidden_to_memory, memory_to_memory, memory_d, steps", + "should_use_fft, hidden_to_memory, memory_to_memory, steps, trainable_theta", [ - (False, False, 1, 5), - (True, False, 1, 5), - (False, True, 1, 5), - (False, False, 2, 5), - (False, False, 1, None), + (True, False, False, 5, False), + (False, True, False, 5, False), + (False, False, True, 5, False), + (False, False, False, None, False), + (False, False, False, 5, True), ], ) -def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d, steps): +def test_fft_auto_swap( + should_use_fft, hidden_to_memory, memory_to_memory, steps, trainable_theta +): lmu = layers.LMU( - memory_d, + 4, 2, 3, tf.keras.layers.Dense(5), hidden_to_memory=hidden_to_memory, memory_to_memory=memory_to_memory, + trainable_theta=trainable_theta, ) lmu.build((32, steps, 8)) - assert isinstance(lmu.layer, tf.keras.layers.RNN) == ( - hidden_to_memory or memory_to_memory or steps is None - ) + assert isinstance(lmu.layer, layers.LMUFFT) == should_use_fft @pytest.mark.parametrize( @@ -430,17 +432,18 @@ def test_dropout( assert np.allclose(y0, y1) -@pytest.mark.parametrize("train_theta", (True, False)) +@pytest.mark.parametrize("trainable_theta", (True, False)) @pytest.mark.parametrize("discretizer", ("zoh", "euler")) @pytest.mark.parametrize("fft", (True, False)) -def test_fit(fft, discretizer, train_theta): - order = 256 - theta_init = (1 if discretizer == "zoh" else 12) * order +def test_fit(fft, discretizer, trainable_theta): + if fft and trainable_theta: + pytest.skip("FFT does not support trainable theta") + lmu_layer = layers.LMU( memory_d=1, - order=order, - theta=theta_init, - train_theta=train_theta, + order=256, + theta=784 if discretizer == "zoh" else 2000, + trainable_theta=trainable_theta, hidden_cell=tf.keras.layers.SimpleRNNCell(units=30), hidden_to_memory=not fft, memory_to_memory=not fft, @@ -497,118 +500,90 @@ def test_discretizer_types(): ) -def test_cont2discrete_zoh(): - lmu_cell = layers.LMUCell( - memory_d=5, order=64, theta=784, hidden_cell=None, discretizer="zoh" +@pytest.mark.parametrize("trainable_theta", (True, False)) +def test_discretizer_equivalence(trainable_theta, rng): + # check that zoh and euler produce approximately the same output + layer_zoh = layers.LMU( + memory_d=2, + order=8, + theta=256, + hidden_cell=None, + discretizer="zoh", + return_sequences=True, + kernel_initializer=None, + trainable_theta=trainable_theta, + ) + layer_euler = layers.LMU( + memory_d=2, + order=8, + theta=256, + hidden_cell=None, + discretizer="euler", + return_sequences=True, + kernel_initializer=None, + trainable_theta=trainable_theta, ) - lmu_cell.build((2, 40)) + x = rng.uniform(-1, 1, size=(32, 10, 2)) - A, B, _, _, _ = cont2discrete( - ( - lmu_cell.CONST_A * lmu_cell.theta_inv, - lmu_cell.CONST_B * lmu_cell.theta_inv, - np.ones(256), - np.ones(256), - ), - dt=1.0, - method="zoh", - ) + zoh = layer_zoh(x) + euler = layer_euler(x) - # slight numerical rounding differences between scipy and tf implementation - assert np.allclose(A.T, lmu_cell._A, atol=1e-6) - assert np.allclose(B.T, lmu_cell._B, atol=1e-6) + assert np.allclose(zoh, euler, atol=0.02), np.max(np.abs(zoh - euler)) -@pytest.mark.parametrize("hidden_cell", (None, tf.keras.layers.SimpleRNNCell(units=30))) -@pytest.mark.parametrize("discretizer", ("euler", "zoh")) -@pytest.mark.parametrize("train_theta", (True, False)) -def test_theta_AB_updates(discretizer, hidden_cell, train_theta, tmp_path): - # typically you need at least 4*order as minimum value of theta with euler - # but we increase it even further since we only have a single memory tape - # and the network is smaller, which might cause even more numerical - # instabilities. - order = 64 - theta_init = 10 * order +def test_cont2discrete_zoh(rng): + A = rng.randn(64, 64) + B = rng.randn(64, 1) + C = np.ones((1, 64)) + D = np.zeros((1,)) + + scipy_A, scipy_B, *_ = cont2discrete((A, B, C, D), dt=1.0, method="zoh") + tf_A, tf_B = layers.LMUCell._cont2discrete_zoh(A.T, B.T) + + assert np.allclose(scipy_A, tf.transpose(tf_A)) + assert np.allclose(scipy_B, tf.transpose(tf_B)) + +@pytest.mark.parametrize("discretizer", ("euler", "zoh")) +@pytest.mark.parametrize("trainable_theta", (True, False)) +def test_theta_update(discretizer, trainable_theta, tmp_path): # create model - lmu_layer = layers.LMU( - memory_d=1, - order=order, - theta=theta_init, - train_theta=train_theta, - hidden_cell=hidden_cell, + theta = 10 + lmu_cell = layers.LMUCell( + memory_d=2, + order=3, + theta=theta, + trainable_theta=trainable_theta, + hidden_cell=tf.keras.layers.SimpleRNNCell(units=4), discretizer=discretizer, - hidden_to_memory=hidden_cell is not None, - memory_to_memory=hidden_cell is not None, - input_to_hidden=hidden_cell is not None, ) inputs = tf.keras.layers.Input((None, 20)) - lmu = lmu_layer(inputs) - outputs = tf.keras.layers.Dense(10, activation="sigmoid")(lmu) - - model = tf.keras.Model(inputs=inputs, outputs=outputs) + lmu = tf.keras.layers.RNN(lmu_cell)(inputs) + model = tf.keras.Model(inputs=inputs, outputs=lmu) model.compile( - loss=tf.keras.losses.MeanSquaredError(), - optimizer=tf.keras.optimizers.Adam(), - metrics=["accuracy"], + loss=tf.keras.losses.MeanSquaredError(), optimizer=tf.keras.optimizers.Adam() ) - # determine index of weights - for index, weight in enumerate(model.layers[1].weights): - if "A" in weight.name: - index_A = index - if "B" in weight.name: - index_B = index - if "theta_kernel" in weight.name: - index_t = index - - # store initial value of A and B - A_init, B_init = ( - model.layers[1].weights[index_A].numpy(), - model.layers[1].weights[index_B].numpy(), - ) - - # make sure theta_inv is set correctly to initital value - assert 1 / theta_init == ( - model.layers[1].weights[index_t].numpy() - if train_theta - else (model.layers[1]).layer.cell.theta_inv - ) + # make sure theta_inv is set correctly to initial value + assert np.allclose(lmu_cell.theta_inv.numpy(), 1 / theta) # fit model on some data - x_train = tf.random.uniform((100, 5, 20), dtype=tf.float32) - y_train = tf.random.uniform((100, 10), dtype=tf.float32) + model.fit(tf.ones((64, 5, 20)), tf.ones((64, 4)), epochs=1) - model.fit(x_train, y_train, epochs=10, validation_split=0.2) + # make sure theta kernel got updated if trained + assert np.allclose(lmu_cell.theta_inv.numpy(), 1 / theta) != trainable_theta - # make sure A and B got updated if trained with zoh - assert np.any(A_init != model.layers[1].weights[index_A].numpy()) == ( - train_theta and discretizer == "zoh" - ) - assert np.any(B_init != model.layers[1].weights[index_B].numpy()) == ( - train_theta and discretizer == "zoh" - ) + # save model and make sure you get same outputs, that is, correct theta was stored + model.save(str(tmp_path)) - # make sure theta kernel got updated if trained - if train_theta: - assert model.layers[1].weights[index_t].numpy() != (1 / theta_init) - - # save model and make sure you get same outputs, that is, correct values of A - # B and theta were stored after training - if hidden_cell: - model.save(str(tmp_path)) - - model_load = tf.keras.models.load_model( - str(tmp_path), - custom_objects={ - "LMU": layers.LMU, - }, - ) + model_load = tf.keras.models.load_model( + str(tmp_path), custom_objects={"LMUCell": layers.LMUCell} + ) - assert np.allclose( - model.predict(np.ones((32, 10, 20))), - model_load.predict(np.ones((32, 10, 20))), - ) + assert np.allclose( + model.predict(np.ones((32, 10, 20))), + model_load.predict(np.ones((32, 10, 20))), + ) diff --git a/keras_lmu/version.py b/keras_lmu/version.py index eda5ed50..4b17531c 100644 --- a/keras_lmu/version.py +++ b/keras_lmu/version.py @@ -7,7 +7,7 @@ """ name = "keras_lmu" -version_info = (0, 3, 2) # (major, minor, patch) +version_info = (0, 4, 0) # (major, minor, patch) dev = 0 # set to None for releases version = (