Skip to content

Commit

Permalink
Raise BuildError if learning objects are on_chip
Browse files Browse the repository at this point in the history
Fixes #208 and #209.
  • Loading branch information
arvoelke authored and tbekolay committed Apr 15, 2019
1 parent a2be2de commit b55dd34
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nengo_loihi/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,18 @@ def __init__(self, network, precompute=False, remove_passthrough=True):
if (conn.learning_rule_type is not None
and isinstance(post, Ensemble)
and post in self._chip_objects):
if network.config[post].on_chip:
raise BuildError("Post ensemble (%r) of learned "
"connection (%r) must not be configured "
"as on_chip." % (post, conn))
self._chip_objects.remove(post)
elif (isinstance(post, LearningRule)
and isinstance(pre, Ensemble)
and pre in self._chip_objects):
if network.config[pre].on_chip:
raise BuildError("Pre ensemble (%r) of error "
"connection (%r) must not be configured "
"as on_chip." % (pre, conn))
self._chip_objects.remove(pre)

# Step 4. Mark passthrough nodes for removal
Expand Down
28 changes: 28 additions & 0 deletions nengo_loihi/tests/test_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,34 @@ def test_split_precompute_loop_error():
SplitterDirective(net, precompute=True)


def test_chip_learning_errors():
with nengo.Network() as net:
add_params(net)

a = nengo.Ensemble(100, 1)
b = nengo.Ensemble(100, 1)
net.config[b].on_chip = True

nengo.Connection(a, b, learning_rule_type=nengo.PES())

with pytest.raises(BuildError, match="Post ensemble"):
SplitterDirective(net)

with nengo.Network() as net:
add_params(net)

a = nengo.Ensemble(100, 1)
b = nengo.Ensemble(100, 1)
error = nengo.Ensemble(100, 1)
net.config[error].on_chip = True

conn = nengo.Connection(a, b, learning_rule_type=nengo.PES())
nengo.Connection(error, conn.learning_rule)

with pytest.raises(BuildError, match="Pre ensemble"):
SplitterDirective(net)


@pytest.mark.parametrize("remove_passthrough", [True, False])
def test_split_remove_passthrough(remove_passthrough):
with nengo.Network() as net:
Expand Down

0 comments on commit b55dd34

Please sign in to comment.