From dd6f5263a696c38f6da28bd3a4201822216f05f7 Mon Sep 17 00:00:00 2001 From: Aaron Voelker Date: Fri, 29 Mar 2019 14:13:26 -0400 Subject: [PATCH] Raise BuildError if learning objects are on_chip Fixes #208 and #209. --- nengo_loihi/new_splitter.py | 8 ++++++++ nengo_loihi/tests/test_new_splitter.py | 28 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/nengo_loihi/new_splitter.py b/nengo_loihi/new_splitter.py index e03246659..2e6e5bcf0 100644 --- a/nengo_loihi/new_splitter.py +++ b/nengo_loihi/new_splitter.py @@ -50,10 +50,18 @@ def __init__(self, network, precompute=False, remove_passthrough=True): if (conn.learning_rule_type is not None and isinstance(post, Ensemble) and post in self._chip_objects): + if network.config[post].on_chip: + raise BuildError("Post ensemble (%r) of learned " + "connection (%r) must not be configured " + "as on_chip." % (post, conn)) self._chip_objects.remove(post) elif (isinstance(post, LearningRule) and isinstance(pre, Ensemble) and pre in self._chip_objects): + if network.config[pre].on_chip: + raise BuildError("Pre ensemble (%r) of error " + "connection (%r) must not be configured " + "as on_chip." % (pre, conn)) self._chip_objects.remove(pre) # Step 4. Mark passthrough nodes for removal diff --git a/nengo_loihi/tests/test_new_splitter.py b/nengo_loihi/tests/test_new_splitter.py index 997645c9f..35f409ccd 100644 --- a/nengo_loihi/tests/test_new_splitter.py +++ b/nengo_loihi/tests/test_new_splitter.py @@ -216,3 +216,31 @@ def test_already_moved_to_host(): splitter_directive = SplitterDirective(net) with pytest.raises(ValueError, match="must be on chip"): splitter_directive.move_to_host(u) + + +def test_chip_learning_errors(): + with nengo.Network() as net: + add_params(net) + + a = nengo.Ensemble(100, 1) + b = nengo.Ensemble(100, 1) + net.config[b].on_chip = True + + nengo.Connection(a, b, learning_rule_type=nengo.PES()) + + with pytest.raises(BuildError, match="Post ensemble"): + SplitterDirective(net) + + with nengo.Network() as net: + add_params(net) + + a = nengo.Ensemble(100, 1) + b = nengo.Ensemble(100, 1) + error = nengo.Ensemble(100, 1) + net.config[error].on_chip = True + + conn = nengo.Connection(a, b, learning_rule_type=nengo.PES()) + nengo.Connection(error, conn.learning_rule) + + with pytest.raises(BuildError, match="Pre ensemble"): + SplitterDirective(net)