Skip to content

Commit

Permalink
squash! Combine splitter into builder
Browse files Browse the repository at this point in the history
Rename SplitterDirective -> Split

SplitterDirective is a bit of a clunky name. I considered keeping
it as Splitter, but I get the concern that this class isn't actually
doing any splitting per se, it only represents information about
how to split a model. I chose "Split" as being the result of a
splitting process (so a "splitter" produces a "split").
  • Loading branch information
tbekolay committed Apr 23, 2019
1 parent a120c31 commit a2243d0
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 69 deletions.
4 changes: 2 additions & 2 deletions nengo_loihi/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def create_host_model(label, dt):
self.blocks = OrderedDict()

# Will be filled in by the simulator __init__
self.splitter_directive = None
self.split = None

# Will be filled in by the network builder
self.toplevel = None
Expand Down Expand Up @@ -202,7 +202,7 @@ def delegate(self, obj):

def build(self, obj, *args, **kwargs):
# Don't build the passthrough nodes or connections
passthrough_directive = self.splitter_directive.passthrough_directive
passthrough_directive = self.split.passthrough_directive
if (isinstance(obj, Node)
and obj in passthrough_directive.removed_passthroughs):
return None
Expand Down
4 changes: 2 additions & 2 deletions nengo_loihi/builder/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ def _inherit_seed(dest_model, dest_obj, src_model, src_obj):

@Builder.register(Connection)
def build_connection(model, conn):
is_pre_chip = model.splitter_directive.on_chip(base_obj(conn.pre))
is_pre_chip = model.split.on_chip(base_obj(conn.pre))

if isinstance(conn.post_obj, LearningRule):
assert not is_pre_chip
return build_host_to_learning_rule(model, conn)

is_post_chip = model.splitter_directive.on_chip(base_obj(conn.post))
is_post_chip = model.split.on_chip(base_obj(conn.post))

if is_pre_chip and is_post_chip:
build_chip_connection(model, conn)
Expand Down
4 changes: 2 additions & 2 deletions nengo_loihi/builder/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def conn_probe(model, nengo_probe):
target = nengo.Node(size_in=output_dim, add_to_container=False)
# TODO: This is a hack so that the builder can properly delegate the
# connection build to the right method
model.splitter_directive._seen_objects.add(target)
model.splitter_directive._chip_objects.add(target)
model.split._seen_objects.add(target)
model.split._chip_objects.add(target)

conn = Connection(
nengo_probe.target,
Expand Down
11 changes: 5 additions & 6 deletions nengo_loihi/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nengo_loihi.discretize import discretize_model
from nengo_loihi.emulator import EmulatorInterface
from nengo_loihi.hardware import HardwareInterface, HAS_NXSDK
from nengo_loihi.splitter import SplitterDirective
from nengo_loihi.splitter import Split

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,17 +143,16 @@ def __init__( # noqa: C901
seeded=self.model.seeded)

# determine how to split the host into one, two or three models
self.model.splitter_directive = SplitterDirective(
network,
precompute=precompute,
remove_passthrough=remove_passthrough)
self.model.split = Split(network,
precompute=precompute,
remove_passthrough=remove_passthrough)

# Build the network into the model
self.model.build(network)

# Build the extra passthrough connections into the model
passthrough_directive = (
self.model.splitter_directive.passthrough_directive)
self.model.split.passthrough_directive)
for conn in passthrough_directive.added_connections:
# https://github.com/nengo/nengo-loihi/issues/210
self.model.seeds[conn] = None
Expand Down
2 changes: 1 addition & 1 deletion nengo_loihi/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
convert_passthroughs, PassthroughDirective, base_obj, is_passthrough)


class SplitterDirective:
class Split:
"""Creates a set of directives to guide the builder."""

def __init__(self, network, precompute=False, remove_passthrough=True):
Expand Down
110 changes: 54 additions & 56 deletions nengo_loihi/tests/test_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from nengo_loihi.config import add_params
from nengo_loihi.splitter import SplitterDirective
from nengo_loihi.splitter import Split


def test_place_nodes():
Expand All @@ -23,13 +23,13 @@ def test_place_nodes():
with nengo.Network():
nowhere = nengo.Node(0)

splitter_directive = SplitterDirective(net)
assert not splitter_directive.on_chip(offchip1)
assert not splitter_directive.on_chip(offchip2)
assert not splitter_directive.on_chip(offchip3)
split = Split(net)
assert not split.on_chip(offchip1)
assert not split.on_chip(offchip2)
assert not split.on_chip(offchip3)

with pytest.raises(IndexError, match="not a part of the network"):
splitter_directive.on_chip(nowhere)
split.on_chip(nowhere)


def test_place_ensembles():
Expand All @@ -49,19 +49,19 @@ def test_place_ensembles():
conn = nengo.Connection(pre, post, learning_rule_type=nengo.PES())
nengo.Connection(error, conn.learning_rule)

splitter_directive = SplitterDirective(net)
assert not splitter_directive.on_chip(offchip)
assert not splitter_directive.on_chip(direct)
assert splitter_directive.on_chip(onchip)
assert splitter_directive.on_chip(pre)
assert not splitter_directive.on_chip(post)
assert not splitter_directive.on_chip(error)
split = Split(net)
assert not split.on_chip(offchip)
assert not split.on_chip(direct)
assert split.on_chip(onchip)
assert split.on_chip(pre)
assert not split.on_chip(post)
assert not split.on_chip(error)

for obj in net.all_ensembles + net.all_nodes:
assert not splitter_directive.is_precomputable(obj)
assert not split.is_precomputable(obj)

with pytest.raises(TypeError, match="Locations are only established"):
splitter_directive.on_chip(conn)
split.on_chip(conn)


def test_place_internetwork_connections():
Expand All @@ -76,19 +76,19 @@ def test_place_internetwork_connections():
offon = nengo.Connection(offchip, onchip)
offoff = nengo.Connection(offchip, offchip)

splitter_directive = SplitterDirective(net)
split = Split(net)

assert splitter_directive.on_chip(onon.pre)
assert splitter_directive.on_chip(onon.post)
assert split.on_chip(onon.pre)
assert split.on_chip(onon.post)

assert splitter_directive.on_chip(onoff.pre)
assert not splitter_directive.on_chip(onoff.post)
assert split.on_chip(onoff.pre)
assert not split.on_chip(onoff.post)

assert not splitter_directive.on_chip(offon.pre)
assert splitter_directive.on_chip(offon.post)
assert not split.on_chip(offon.pre)
assert split.on_chip(offon.post)

assert not splitter_directive.on_chip(offoff.pre)
assert not splitter_directive.on_chip(offoff.post)
assert not split.on_chip(offoff.pre)
assert not split.on_chip(offoff.post)


def test_split_host_to_learning_rule():
Expand All @@ -109,13 +109,13 @@ def test_split_host_to_learning_rule():
nengo.Connection(
err_offchip, neurons_conn.learning_rule)

splitter_directive = SplitterDirective(net)
split = Split(net)

assert splitter_directive.on_chip(pre)
assert not splitter_directive.on_chip(post)
assert split.on_chip(pre)
assert not split.on_chip(post)

assert not splitter_directive.on_chip(err_onchip)
assert not splitter_directive.on_chip(err_offchip)
assert not split.on_chip(err_onchip)
assert not split.on_chip(err_offchip)


def test_precompute_host_to_learning_rule_unsupported():
Expand All @@ -127,7 +127,7 @@ def test_precompute_host_to_learning_rule_unsupported():
nengo.Connection(pre, post, learning_rule_type=nengo.PES())

with pytest.raises(BuildError, match="learning rules"):
SplitterDirective(net, precompute=True)
Split(net, precompute=True)


def test_place_probes():
Expand All @@ -150,13 +150,13 @@ def test_place_probes():
nengo.Probe(onchip2),
]

splitter_directive = SplitterDirective(net)
assert splitter_directive.on_chip(onchip1)
assert splitter_directive.on_chip(onchip2)
assert not splitter_directive.on_chip(offchip1)
assert not splitter_directive.on_chip(offchip2)
assert not any(splitter_directive.on_chip(p) for p in offchip_probes)
assert all(splitter_directive.on_chip(p) for p in onchip_probes)
split = Split(net)
assert split.on_chip(onchip1)
assert split.on_chip(onchip2)
assert not split.on_chip(offchip1)
assert not split.on_chip(offchip2)
assert not any(split.on_chip(p) for p in offchip_probes)
assert all(split.on_chip(p) for p in onchip_probes)


def test_split_pre_from_host():
Expand Down Expand Up @@ -184,23 +184,23 @@ def test_split_pre_from_host():
net.config[pre_4].on_chip = False
net.config[post1].on_chip = False

splitter_directive = SplitterDirective(net, precompute=True)
split = Split(net, precompute=True)

host_precomputable = {pre_1, pre_2, pre_3, pre_4, pre_5}
for obj in host_precomputable:
assert not splitter_directive.on_chip(obj)
assert splitter_directive.is_precomputable(obj)
assert not split.on_chip(obj)
assert split.is_precomputable(obj)

host_nonprecomputable = {post1, post2, post3}
for obj in host_nonprecomputable:
assert not splitter_directive.on_chip(obj)
assert not splitter_directive.is_precomputable(obj)
assert not split.on_chip(obj)
assert not split.is_precomputable(obj)

assert splitter_directive.on_chip(onchip)
assert not splitter_directive.is_precomputable(onchip)
assert split.on_chip(onchip)
assert not split.is_precomputable(onchip)

with pytest.raises(IndexError, match="not a part of the network"):
splitter_directive.is_precomputable(
split.is_precomputable(
nengo.Node(0, add_to_container=False))


Expand All @@ -213,7 +213,7 @@ def test_split_precompute_loop_error():
nengo.Connection(ens_onchip, node_offchip)

with pytest.raises(BuildError, match="Cannot precompute"):
SplitterDirective(net, precompute=True)
Split(net, precompute=True)


@pytest.mark.parametrize("remove_passthrough", [True, False])
Expand Down Expand Up @@ -243,11 +243,10 @@ def test_split_remove_passthrough(remove_passthrough):
conn4 = nengo.Connection(discard2, chip3)
nengo.Connection(chip3, keep4)

splitter_directive = SplitterDirective(
net, remove_passthrough=remove_passthrough)
assert not splitter_directive.on_chip(probe)
split = Split(net, remove_passthrough=remove_passthrough)
assert not split.on_chip(probe)

pd = splitter_directive.passthrough_directive
pd = split.passthrough_directive

if remove_passthrough:
assert pd.removed_passthroughs == {discard1, discard2}
Expand Down Expand Up @@ -280,14 +279,13 @@ def test_precompute_remove_passthrough():
nengo.Connection(onchip2, passthrough2)
nengo.Connection(passthrough2, onchip3)

splitter_directive = SplitterDirective(
net, precompute=True, remove_passthrough=True)
split = Split(net, precompute=True, remove_passthrough=True)

assert splitter_directive.is_precomputable(host)
assert not splitter_directive.on_chip(host)
assert split.is_precomputable(host)
assert not split.on_chip(host)

for obj in (onchip1, passthrough1, onchip2, passthrough2, onchip3):
assert not splitter_directive.is_precomputable(obj)
assert not split.is_precomputable(obj)

for obj in (onchip1, onchip2, onchip3):
assert splitter_directive.on_chip(obj)
assert split.on_chip(obj)

0 comments on commit a2243d0

Please sign in to comment.