From 5694c277c38865d8d41834b3dd58df40f2a5b708 Mon Sep 17 00:00:00 2001 From: Trevor Bekolay Date: Mon, 22 Apr 2019 21:16:59 -0400 Subject: [PATCH] squash! Combine splitter into builder Rename PassthroughDirective -> PassthroughSplit Like SplitterDirective, PassthroughDirective is a bit clunky, and definitely doesn't make sense if there is no SplitterDirective. This name attempts to make the relationship between Split and PassthroughSplit more clear, which is that a PassthroughSplit is a specific type of Split (though they are not a class/subclass relationship, so it's only a specific type in a conceptual sense). --- nengo_loihi/builder/builder.py | 12 +- nengo_loihi/builder/probe.py | 1 - nengo_loihi/passthrough.py | 181 +++++++++++++------------- nengo_loihi/simulator.py | 8 +- nengo_loihi/splitter.py | 15 +-- nengo_loihi/tests/test_passthrough.py | 50 ++++--- nengo_loihi/tests/test_splitter.py | 12 +- 7 files changed, 136 insertions(+), 143 deletions(-) diff --git a/nengo_loihi/builder/builder.py b/nengo_loihi/builder/builder.py index 9fd3fc270..3dfe3d839 100644 --- a/nengo_loihi/builder/builder.py +++ b/nengo_loihi/builder/builder.py @@ -1,7 +1,7 @@ from collections import defaultdict, OrderedDict import logging -from nengo import Network, Node, Ensemble, Connection, Probe +from nengo import Ensemble, Network, Node, Probe from nengo.builder import Model as NengoModel from nengo.builder.builder import Builder as NengoBuilder from nengo.builder.network import build_network @@ -201,13 +201,9 @@ def delegate(self, obj): return self.host def build(self, obj, *args, **kwargs): - # Don't build the passthrough nodes or connections - passthrough_directive = self.split.passthrough_directive - if (isinstance(obj, Node) - and obj in passthrough_directive.removed_passthroughs): - return None - if (isinstance(obj, Connection) - and obj in passthrough_directive.removed_connections): + # Don't build the objects marked as "to_remove" by PassthroughSplit + passthrough = self.split.passthrough + if obj in passthrough.to_remove: return None # Note: any callbacks for host_pre or host will not be invoked here diff --git a/nengo_loihi/builder/probe.py b/nengo_loihi/builder/probe.py index 9d9c049ae..c4b972dcd 100644 --- a/nengo_loihi/builder/probe.py +++ b/nengo_loihi/builder/probe.py @@ -1,6 +1,5 @@ import nengo from nengo import Ensemble, Connection, Node -from nengo.base import ObjView from nengo.connection import LearningRule from nengo.ensemble import Neurons from nengo.exceptions import BuildError diff --git a/nengo_loihi/passthrough.py b/nengo_loihi/passthrough.py index bb96944bf..726fbd5ea 100644 --- a/nengo_loihi/passthrough.py +++ b/nengo_loihi/passthrough.py @@ -1,4 +1,4 @@ -from collections import OrderedDict, namedtuple +from collections import OrderedDict import warnings from nengo import Connection, Lowpass, Node @@ -10,10 +10,6 @@ from nengo_loihi.compat import nengo_transforms, transform_array -PassthroughDirective = namedtuple( - "PassthroughDirective", - ["removed_passthroughs", "removed_connections", "added_connections"]) - def is_passthrough(obj): return isinstance(obj, Node) and obj.output is None @@ -216,51 +212,7 @@ def generate_conns(self): ) -def find_clusters(net, ignore): - """Create the Clusters for a given nengo Network.""" - - # find which objects have Probes, as we need to make sure to keep them - probed_objs = set(base_obj(p.target) for p in net.all_probes) - - clusters = OrderedDict() # mapping from object to its Cluster - for c in net.all_connections: - base_pre = base_obj(c.pre) - base_post = base_obj(c.post) - - pass_pre = is_passthrough(c.pre_obj) and c.pre_obj not in ignore - if pass_pre and c.pre_obj not in clusters: - # add new objects to their own initial Cluster - clusters[c.pre_obj] = Cluster(c.pre_obj) - if c.pre_obj in probed_objs: - clusters[c.pre_obj].probed_objs.add(c.pre_obj) - - pass_post = is_passthrough(c.post_obj) and c.post_obj not in ignore - if pass_post and c.post_obj not in clusters: - # add new objects to their own initial Cluster - clusters[c.post_obj] = Cluster(c.post_obj) - if c.post_obj in probed_objs: - clusters[c.post_obj].probed_objs.add(c.post_obj) - - if pass_pre and pass_post: - # both pre and post are passthrough, so merge the two - # clusters into one cluster - cluster = clusters[base_pre] - cluster.merge_with(clusters[base_post]) - for obj in cluster.objs: - clusters[obj] = cluster - cluster.conns_mid.add(c) - elif pass_pre: - # pre is passthrough but post is not, so this is an output - cluster = clusters[base_pre] - cluster.conns_out.add(c) - elif pass_post: - # pre is not a passthrough but post is, so this is an input - cluster = clusters[base_post] - cluster.conns_in.add(c) - return clusters - - -def convert_passthroughs(network, ignore): +class PassthroughSplit: """Create a set of Connections that could replace the passthrough Nodes. This does not actually modify the Network, but instead returns the @@ -273,40 +225,95 @@ def convert_passthroughs(network, ignore): The system will only remove passthrough Nodes where neither pre nor post are ignored. """ - clusters = find_clusters(network, ignore=ignore) - - removed_passthroughs = set() - removed_connections = set() - added_connections = set() - handled_clusters = set() - for cluster in clusters.values(): - if cluster not in handled_clusters: - handled_clusters.add(cluster) - onchip_input = False - onchip_output = False - for c in cluster.conns_in: - if base_obj(c.pre) not in ignore: - onchip_input = True - break - for c in cluster.conns_out: - if base_obj(c.post) not in ignore: - onchip_output = True - break - has_input = len(cluster.conns_in) > 0 - no_output = len(cluster.conns_out) + len(cluster.probed_objs) == 0 - - if has_input and ((onchip_input and onchip_output) or no_output): - try: - new_conns = list(cluster.generate_conns()) - except ClusterError: - # this Cluster has an issue, so don't remove it - continue - - removed_passthroughs.update(cluster.objs - cluster.probed_objs) - removed_connections.update(cluster.conns_in - | cluster.conns_mid - | cluster.conns_out) - added_connections.update(new_conns) - - return PassthroughDirective( - removed_passthroughs, removed_connections, added_connections) + + def __init__(self, network, ignore=None): + self.network = network + self.ignore = ignore if ignore is not None else set() + + self.to_remove = set() + self.to_add = set() + + if self.network is not None: + self.clusters = self._find_clusters() + self._already_split = set() + for cluster in self.clusters.values(): + if cluster not in self._already_split: + self._split_cluster(cluster) + + @property + def connections_to_remove(self): + return set(c for c in self.to_remove if isinstance(c, Connection)) + + @property + def nodes_to_remove(self): + return set(n for n in self.to_remove if isinstance(n, Node)) + + def _find_clusters(self): + """Find Clusters for the given Network.""" + + # find which objects have Probes, as we need to make sure to keep them + probed_objs = set(base_obj(p.target) for p in self.network.all_probes) + + clusters = OrderedDict() # mapping from object to its Cluster + for c in self.network.all_connections: + base_pre = base_obj(c.pre) + base_post = base_obj(c.post) + + pass_pre = (is_passthrough(c.pre_obj) + and c.pre_obj not in self.ignore) + if pass_pre and c.pre_obj not in clusters: + # add new objects to their own initial Cluster + clusters[c.pre_obj] = Cluster(c.pre_obj) + if c.pre_obj in probed_objs: + clusters[c.pre_obj].probed_objs.add(c.pre_obj) + + pass_post = (is_passthrough(c.post_obj) + and c.post_obj not in self.ignore) + if pass_post and c.post_obj not in clusters: + # add new objects to their own initial Cluster + clusters[c.post_obj] = Cluster(c.post_obj) + if c.post_obj in probed_objs: + clusters[c.post_obj].probed_objs.add(c.post_obj) + + if pass_pre and pass_post: + # both pre and post are passthrough, so merge the two + # clusters into one cluster + cluster = clusters[base_pre] + cluster.merge_with(clusters[base_post]) + for obj in cluster.objs: + clusters[obj] = cluster + cluster.conns_mid.add(c) + elif pass_pre: + # pre is passthrough but post is not, so this is an output + cluster = clusters[base_pre] + cluster.conns_out.add(c) + elif pass_post: + # pre is not a passthrough but post is, so this is an input + cluster = clusters[base_post] + cluster.conns_in.add(c) + return clusters + + def _split_cluster(self, cluster): + """Split a Cluster.""" + assert cluster not in self._already_split + self._already_split.add(cluster) + + onchip_input = any(base_obj(c.pre) not in self.ignore + for c in cluster.conns_in) + onchip_output = any(base_obj(c.post) not in self.ignore + for c in cluster.conns_out) + + has_input = len(cluster.conns_in) > 0 + no_output = len(cluster.conns_out) + len(cluster.probed_objs) == 0 + + if has_input and ((onchip_input and onchip_output) or no_output): + try: + new_conns = list(cluster.generate_conns()) + except ClusterError: + # this Cluster has an issue, so don't remove it + return + + self.to_remove.update(cluster.objs - cluster.probed_objs) + self.to_remove.update( + cluster.conns_in | cluster.conns_mid | cluster.conns_out) + self.to_add.update(new_conns) diff --git a/nengo_loihi/simulator.py b/nengo_loihi/simulator.py index 99c634bfd..dca0cd6ad 100644 --- a/nengo_loihi/simulator.py +++ b/nengo_loihi/simulator.py @@ -151,10 +151,10 @@ def __init__( # noqa: C901 self.model.build(network) # Build the extra passthrough connections into the model - passthrough_directive = ( - self.model.split.passthrough_directive) - for conn in passthrough_directive.added_connections: - # https://github.com/nengo/nengo-loihi/issues/210 + passthrough = self.model.split.passthrough + for conn in passthrough.to_add: + # Note: connections added by the passthrough splitter do not + # respect seeds self.model.seeds[conn] = None self.model.seeded[conn] = False self.model.build(conn) diff --git a/nengo_loihi/splitter.py b/nengo_loihi/splitter.py index b0f270d88..df9347639 100644 --- a/nengo_loihi/splitter.py +++ b/nengo_loihi/splitter.py @@ -4,8 +4,7 @@ from nengo.exceptions import BuildError from nengo.connection import LearningRule -from nengo_loihi.passthrough import ( - convert_passthroughs, PassthroughDirective, base_obj, is_passthrough) +from nengo_loihi.passthrough import base_obj, is_passthrough, PassthroughSplit class Split: @@ -51,10 +50,9 @@ def __init__(self, network, precompute=False, remove_passthrough=True): passthroughs = set( obj for obj in network.all_nodes if is_passthrough(obj)) ignore = self._seen_objects - self._chip_objects - passthroughs - self.passthrough_directive = convert_passthroughs(network, ignore) + self.passthrough = PassthroughSplit(network, ignore) else: - self.passthrough_directive = PassthroughDirective( - set(), set(), set()) + self.passthrough = PassthroughSplit(None) # Step 5. Split precomputable parts of host # This is a subset of host, marking which are precomputable @@ -86,8 +84,7 @@ def mark_precomputable(obj): # determine which connections will actually be built conns = ((set(self.network.all_connections) - | self.passthrough_directive.added_connections) - - self.passthrough_directive.removed_connections) + | self.passthrough.to_add) - self.passthrough.to_remove) # Initialize queue with the pre objects on host->chip connections. # We assume that all `conn.pre` objects are pre-computable, and then @@ -102,8 +99,8 @@ def mark_precomputable(obj): pre, post = base_obj(conn.pre), base_obj(conn.post) pre_to_conn[pre].append(conn) post_to_conn[post].append(conn) - assert pre not in self.passthrough_directive.removed_passthroughs - assert post not in self.passthrough_directive.removed_passthroughs + assert pre not in self.passthrough.to_remove + assert post not in self.passthrough.to_remove if (isinstance(post, LearningRule) or conn.learning_rule is not None): diff --git a/nengo_loihi/tests/test_passthrough.py b/nengo_loihi/tests/test_passthrough.py index 8e91b037a..a3df090b9 100644 --- a/nengo_loihi/tests/test_passthrough.py +++ b/nengo_loihi/tests/test_passthrough.py @@ -5,7 +5,7 @@ from nengo_loihi.compat import transform_array from nengo_loihi.decode_neurons import OnOffDecodeNeurons -from nengo_loihi.passthrough import convert_passthroughs +from nengo_loihi.passthrough import PassthroughSplit default_node_neurons = OnOffDecodeNeurons() @@ -29,16 +29,14 @@ def test_passthrough_placement(): nengo.Connection(f, g) nengo.Probe(g) - passthrough_directive = convert_passthroughs(model, ignore={stim}) + split = PassthroughSplit(model, ignore={stim}) - assert passthrough_directive.removed_passthroughs == {c, d, e} - assert passthrough_directive.removed_connections == { - conn_bc, conn_cd, conn_de, conn_ef} + assert split.to_remove == {c, d, e, conn_bc, conn_cd, conn_de, conn_ef} - conns = list(passthrough_directive.added_connections) - assert len(conns) == 1 - assert conns[0].pre is b - assert conns[0].post is f + assert len(split.to_add) == 1 + conn = next(iter(split.to_add)) + assert conn.pre is b + assert conn.post is f @pytest.mark.parametrize("d1", [1, 3]) @@ -56,14 +54,13 @@ def test_transform_merging(d1, d2, d3): conn_ab = nengo.Connection(a, b, transform=t1) conn_bc = nengo.Connection(b, c, transform=t2) - passthrough_directive = convert_passthroughs(model, ignore=set()) + split = PassthroughSplit(model) - assert passthrough_directive.removed_passthroughs == {b} - assert passthrough_directive.removed_connections == {conn_ab, conn_bc} + assert split.to_remove == {b, conn_ab, conn_bc} - conns = list(passthrough_directive.added_connections) - assert len(conns) == 1 - assert np.allclose(transform_array(conns[0].transform), np.dot(t2, t1)) + assert len(split.to_add) == 1 + conn = next(iter(split.to_add)) + assert np.allclose(transform_array(conn.transform), np.dot(t2, t1)) @pytest.mark.parametrize("n_ensembles", [1, 3]) @@ -74,14 +71,13 @@ def test_identity_array(n_ensembles, ens_dimensions): b = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions) nengo.Connection(a.output, b.input) - passthrough_directive = convert_passthroughs(model, ignore=set()) + split = PassthroughSplit(model) - conns = list(passthrough_directive.added_connections) - assert len(conns) == n_ensembles + assert len(split.to_add) == n_ensembles pre = set() post = set() - for conn in conns: + for conn in split.to_add: assert conn.pre in a.all_ensembles or conn.pre_obj is a.input assert conn.post in b.all_ensembles assert np.allclose(transform_array(conn.transform), @@ -101,13 +97,12 @@ def test_full_array(n_ensembles, ens_dimensions): D = n_ensembles * ens_dimensions nengo.Connection(a.output, b.input, transform=np.ones((D, D))) - passthrough_directive = convert_passthroughs(model, ignore=set()) + split = PassthroughSplit(model) - conns = list(passthrough_directive.added_connections) - assert len(conns) == n_ensembles ** 2 + assert len(split.to_add) == n_ensembles ** 2 pairs = set() - for conn in conns: + for conn in split.to_add: assert conn.pre in a.all_ensembles assert conn.post in b.all_ensembles assert np.allclose(transform_array(conn.transform), @@ -128,10 +123,9 @@ def test_synapse_merging(Simulator, seed): nengo.Connection(b[1], c.input[0], synapse=None) nengo.Connection(b[1], c.input[1], synapse=0.2) - passthrough_directive = convert_passthroughs(model, ignore=set()) + split = PassthroughSplit(model) - conns = list(passthrough_directive.added_connections) - assert len(conns) == 4 + assert len(split.to_add) == 4 desired_filters = { ('0', '0'): None, @@ -139,7 +133,7 @@ def test_synapse_merging(Simulator, seed): ('1', '0'): 0.1, ('1', '1'): 0.3, } - for conn in conns: + for conn in split.to_add: if desired_filters[(conn.pre.label, conn.post.label)] is None: assert conn.synapse is None else: @@ -234,7 +228,7 @@ def make_net(learn_error=False, loop=False): return net, probes - # Since `convert_passthroughs` catches its own cluster errors, we won't see + # Since `PassthroughSplit` catches its own cluster errors, we won't see # the error here. We ensure identical behaviour (so nodes are not removed). # Test learning rule node input diff --git a/nengo_loihi/tests/test_splitter.py b/nengo_loihi/tests/test_splitter.py index 96ad128ae..9c281748b 100644 --- a/nengo_loihi/tests/test_splitter.py +++ b/nengo_loihi/tests/test_splitter.py @@ -246,20 +246,20 @@ def test_split_remove_passthrough(remove_passthrough): split = Split(net, remove_passthrough=remove_passthrough) assert not split.on_chip(probe) - pd = split.passthrough_directive - if remove_passthrough: - assert pd.removed_passthroughs == {discard1, discard2} - assert pd.removed_connections == {conn1, conn2, conn3, conn4} + assert split.passthrough.to_remove == { + discard1, discard2, conn1, conn2, conn3, conn4, + } - conns = list(pd.added_connections) + conns = list(split.passthrough.to_add) assert len(conns) == 2 prepost = {(conn.pre, conn.post) for conn in conns} assert prepost == {(chip1, chip2), (chip2, chip3)} else: - assert pd == (set(), set(), set()) + assert split.passthrough.to_remove == set() + assert split.passthrough.to_add == set() def test_precompute_remove_passthrough():