Skip to content

Commit

Permalink
squash! Combine splitter into builder
Browse files Browse the repository at this point in the history
Rename PassthroughDirective -> PassthroughSplit

Like SplitterDirective, PassthroughDirective is a bit clunky, and
definitely doesn't make sense if there is no SplitterDirective.
This name attempts to make the relationship between Split and
PassthroughSplit more clear, which is that a PassthroughSplit is
a specific type of Split (though they are not a class/subclass
relationship, so it's only a specific type in a conceptual sense).
  • Loading branch information
tbekolay committed Apr 23, 2019
1 parent a2243d0 commit 5694c27
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 143 deletions.
12 changes: 4 additions & 8 deletions nengo_loihi/builder/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict, OrderedDict
import logging

from nengo import Network, Node, Ensemble, Connection, Probe
from nengo import Ensemble, Network, Node, Probe
from nengo.builder import Model as NengoModel
from nengo.builder.builder import Builder as NengoBuilder
from nengo.builder.network import build_network
Expand Down Expand Up @@ -201,13 +201,9 @@ def delegate(self, obj):
return self.host

def build(self, obj, *args, **kwargs):
# Don't build the passthrough nodes or connections
passthrough_directive = self.split.passthrough_directive
if (isinstance(obj, Node)
and obj in passthrough_directive.removed_passthroughs):
return None
if (isinstance(obj, Connection)
and obj in passthrough_directive.removed_connections):
# Don't build the objects marked as "to_remove" by PassthroughSplit
passthrough = self.split.passthrough
if obj in passthrough.to_remove:
return None

# Note: any callbacks for host_pre or host will not be invoked here
Expand Down
1 change: 0 additions & 1 deletion nengo_loihi/builder/probe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import nengo
from nengo import Ensemble, Connection, Node
from nengo.base import ObjView
from nengo.connection import LearningRule
from nengo.ensemble import Neurons
from nengo.exceptions import BuildError
Expand Down
181 changes: 94 additions & 87 deletions nengo_loihi/passthrough.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import OrderedDict, namedtuple
from collections import OrderedDict
import warnings

from nengo import Connection, Lowpass, Node
Expand All @@ -10,10 +10,6 @@

from nengo_loihi.compat import nengo_transforms, transform_array

PassthroughDirective = namedtuple(
"PassthroughDirective",
["removed_passthroughs", "removed_connections", "added_connections"])


def is_passthrough(obj):
return isinstance(obj, Node) and obj.output is None
Expand Down Expand Up @@ -216,51 +212,7 @@ def generate_conns(self):
)


def find_clusters(net, ignore):
"""Create the Clusters for a given nengo Network."""

# find which objects have Probes, as we need to make sure to keep them
probed_objs = set(base_obj(p.target) for p in net.all_probes)

clusters = OrderedDict() # mapping from object to its Cluster
for c in net.all_connections:
base_pre = base_obj(c.pre)
base_post = base_obj(c.post)

pass_pre = is_passthrough(c.pre_obj) and c.pre_obj not in ignore
if pass_pre and c.pre_obj not in clusters:
# add new objects to their own initial Cluster
clusters[c.pre_obj] = Cluster(c.pre_obj)
if c.pre_obj in probed_objs:
clusters[c.pre_obj].probed_objs.add(c.pre_obj)

pass_post = is_passthrough(c.post_obj) and c.post_obj not in ignore
if pass_post and c.post_obj not in clusters:
# add new objects to their own initial Cluster
clusters[c.post_obj] = Cluster(c.post_obj)
if c.post_obj in probed_objs:
clusters[c.post_obj].probed_objs.add(c.post_obj)

if pass_pre and pass_post:
# both pre and post are passthrough, so merge the two
# clusters into one cluster
cluster = clusters[base_pre]
cluster.merge_with(clusters[base_post])
for obj in cluster.objs:
clusters[obj] = cluster
cluster.conns_mid.add(c)
elif pass_pre:
# pre is passthrough but post is not, so this is an output
cluster = clusters[base_pre]
cluster.conns_out.add(c)
elif pass_post:
# pre is not a passthrough but post is, so this is an input
cluster = clusters[base_post]
cluster.conns_in.add(c)
return clusters


def convert_passthroughs(network, ignore):
class PassthroughSplit:
"""Create a set of Connections that could replace the passthrough Nodes.
This does not actually modify the Network, but instead returns the
Expand All @@ -273,40 +225,95 @@ def convert_passthroughs(network, ignore):
The system will only remove passthrough Nodes where neither pre nor post
are ignored.
"""
clusters = find_clusters(network, ignore=ignore)

removed_passthroughs = set()
removed_connections = set()
added_connections = set()
handled_clusters = set()
for cluster in clusters.values():
if cluster not in handled_clusters:
handled_clusters.add(cluster)
onchip_input = False
onchip_output = False
for c in cluster.conns_in:
if base_obj(c.pre) not in ignore:
onchip_input = True
break
for c in cluster.conns_out:
if base_obj(c.post) not in ignore:
onchip_output = True
break
has_input = len(cluster.conns_in) > 0
no_output = len(cluster.conns_out) + len(cluster.probed_objs) == 0

if has_input and ((onchip_input and onchip_output) or no_output):
try:
new_conns = list(cluster.generate_conns())
except ClusterError:
# this Cluster has an issue, so don't remove it
continue

removed_passthroughs.update(cluster.objs - cluster.probed_objs)
removed_connections.update(cluster.conns_in
| cluster.conns_mid
| cluster.conns_out)
added_connections.update(new_conns)

return PassthroughDirective(
removed_passthroughs, removed_connections, added_connections)

def __init__(self, network, ignore=None):
self.network = network
self.ignore = ignore if ignore is not None else set()

self.to_remove = set()
self.to_add = set()

if self.network is not None:
self.clusters = self._find_clusters()
self._already_split = set()
for cluster in self.clusters.values():
if cluster not in self._already_split:
self._split_cluster(cluster)

@property
def connections_to_remove(self):
return set(c for c in self.to_remove if isinstance(c, Connection))

@property
def nodes_to_remove(self):
return set(n for n in self.to_remove if isinstance(n, Node))

def _find_clusters(self):
"""Find Clusters for the given Network."""

# find which objects have Probes, as we need to make sure to keep them
probed_objs = set(base_obj(p.target) for p in self.network.all_probes)

clusters = OrderedDict() # mapping from object to its Cluster
for c in self.network.all_connections:
base_pre = base_obj(c.pre)
base_post = base_obj(c.post)

This comment has been minimized.

Copy link
@arvoelke

arvoelke Apr 23, 2019

Contributor

I've had to remind myself of this a couple times now (for all of the base_obj calls and pre_obj references) of the following fact, so it might be worth adding a comment to the effect of: We assume neither pre nor post can be a probe (since they are end-points of a connection), and so we do not need to resolve any underlying .target objects here.

pass_pre = (is_passthrough(c.pre_obj)
and c.pre_obj not in self.ignore)
if pass_pre and c.pre_obj not in clusters:
# add new objects to their own initial Cluster
clusters[c.pre_obj] = Cluster(c.pre_obj)
if c.pre_obj in probed_objs:
clusters[c.pre_obj].probed_objs.add(c.pre_obj)

pass_post = (is_passthrough(c.post_obj)
and c.post_obj not in self.ignore)
if pass_post and c.post_obj not in clusters:
# add new objects to their own initial Cluster
clusters[c.post_obj] = Cluster(c.post_obj)
if c.post_obj in probed_objs:
clusters[c.post_obj].probed_objs.add(c.post_obj)

if pass_pre and pass_post:
# both pre and post are passthrough, so merge the two
# clusters into one cluster
cluster = clusters[base_pre]
cluster.merge_with(clusters[base_post])
for obj in cluster.objs:
clusters[obj] = cluster
cluster.conns_mid.add(c)
elif pass_pre:
# pre is passthrough but post is not, so this is an output
cluster = clusters[base_pre]
cluster.conns_out.add(c)
elif pass_post:
# pre is not a passthrough but post is, so this is an input
cluster = clusters[base_post]
cluster.conns_in.add(c)
return clusters

def _split_cluster(self, cluster):
"""Split a Cluster."""
assert cluster not in self._already_split
self._already_split.add(cluster)

onchip_input = any(base_obj(c.pre) not in self.ignore
for c in cluster.conns_in)
onchip_output = any(base_obj(c.post) not in self.ignore
for c in cluster.conns_out)

has_input = len(cluster.conns_in) > 0
no_output = len(cluster.conns_out) + len(cluster.probed_objs) == 0

if has_input and ((onchip_input and onchip_output) or no_output):
try:
new_conns = list(cluster.generate_conns())
except ClusterError:
# this Cluster has an issue, so don't remove it
return

self.to_remove.update(cluster.objs - cluster.probed_objs)
self.to_remove.update(
cluster.conns_in | cluster.conns_mid | cluster.conns_out)
self.to_add.update(new_conns)
8 changes: 4 additions & 4 deletions nengo_loihi/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ def __init__( # noqa: C901
self.model.build(network)

# Build the extra passthrough connections into the model
passthrough_directive = (
self.model.split.passthrough_directive)
for conn in passthrough_directive.added_connections:
# https://github.com/nengo/nengo-loihi/issues/210
passthrough = self.model.split.passthrough
for conn in passthrough.to_add:
# Note: connections added by the passthrough splitter do not
# respect seeds
self.model.seeds[conn] = None
self.model.seeded[conn] = False
self.model.build(conn)
Expand Down
15 changes: 6 additions & 9 deletions nengo_loihi/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from nengo.exceptions import BuildError
from nengo.connection import LearningRule

from nengo_loihi.passthrough import (
convert_passthroughs, PassthroughDirective, base_obj, is_passthrough)
from nengo_loihi.passthrough import base_obj, is_passthrough, PassthroughSplit


class Split:
Expand Down Expand Up @@ -51,10 +50,9 @@ def __init__(self, network, precompute=False, remove_passthrough=True):
passthroughs = set(
obj for obj in network.all_nodes if is_passthrough(obj))
ignore = self._seen_objects - self._chip_objects - passthroughs
self.passthrough_directive = convert_passthroughs(network, ignore)
self.passthrough = PassthroughSplit(network, ignore)
else:
self.passthrough_directive = PassthroughDirective(
set(), set(), set())
self.passthrough = PassthroughSplit(None)

# Step 5. Split precomputable parts of host
# This is a subset of host, marking which are precomputable
Expand Down Expand Up @@ -86,8 +84,7 @@ def mark_precomputable(obj):

# determine which connections will actually be built
conns = ((set(self.network.all_connections)
| self.passthrough_directive.added_connections)
- self.passthrough_directive.removed_connections)
| self.passthrough.to_add) - self.passthrough.to_remove)

# Initialize queue with the pre objects on host->chip connections.
# We assume that all `conn.pre` objects are pre-computable, and then
Expand All @@ -102,8 +99,8 @@ def mark_precomputable(obj):
pre, post = base_obj(conn.pre), base_obj(conn.post)
pre_to_conn[pre].append(conn)
post_to_conn[post].append(conn)
assert pre not in self.passthrough_directive.removed_passthroughs
assert post not in self.passthrough_directive.removed_passthroughs
assert pre not in self.passthrough.to_remove
assert post not in self.passthrough.to_remove

if (isinstance(post, LearningRule)
or conn.learning_rule is not None):
Expand Down
Loading

0 comments on commit 5694c27

Please sign in to comment.