diff --git a/nengo_loihi/splitter.py b/nengo_loihi/splitter.py index d7b816b4d..93aa74f5e 100644 --- a/nengo_loihi/splitter.py +++ b/nengo_loihi/splitter.py @@ -40,10 +40,18 @@ def __init__(self, network, precompute=False, remove_passthrough=True): if (conn.learning_rule_type is not None and isinstance(post, Ensemble) and post in self._chip_objects): + if network.config[post].on_chip: + raise BuildError("Post ensemble (%r) of learned " + "connection (%r) must not be configured " + "as on_chip." % (post, conn)) self._chip_objects.remove(post) elif (isinstance(post, LearningRule) and isinstance(pre, Ensemble) and pre in self._chip_objects): + if network.config[pre].on_chip: + raise BuildError("Pre ensemble (%r) of error " + "connection (%r) must not be configured " + "as on_chip." % (pre, conn)) self._chip_objects.remove(pre) # Step 4. Mark passthrough nodes for removal diff --git a/nengo_loihi/tests/test_splitter.py b/nengo_loihi/tests/test_splitter.py index 456b15190..fa66ef823 100644 --- a/nengo_loihi/tests/test_splitter.py +++ b/nengo_loihi/tests/test_splitter.py @@ -216,6 +216,34 @@ def test_split_precompute_loop_error(): SplitterDirective(net, precompute=True) +def test_chip_learning_errors(): + with nengo.Network() as net: + add_params(net) + + a = nengo.Ensemble(100, 1) + b = nengo.Ensemble(100, 1) + net.config[b].on_chip = True + + nengo.Connection(a, b, learning_rule_type=nengo.PES()) + + with pytest.raises(BuildError, match="Post ensemble"): + SplitterDirective(net) + + with nengo.Network() as net: + add_params(net) + + a = nengo.Ensemble(100, 1) + b = nengo.Ensemble(100, 1) + error = nengo.Ensemble(100, 1) + net.config[error].on_chip = True + + conn = nengo.Connection(a, b, learning_rule_type=nengo.PES()) + nengo.Connection(error, conn.learning_rule) + + with pytest.raises(BuildError, match="Pre ensemble"): + SplitterDirective(net) + + @pytest.mark.parametrize("remove_passthrough", [True, False]) def test_split_remove_passthrough(remove_passthrough): with nengo.Network() as net: