Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for nengo_spa. #891

Merged
merged 3 commits into from
May 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Release History
0.3.2 (unreleased)
==================

- Support for nengo_spa
- Added --browser option
- Added --unsecure option
- Fixed backspace not working on sliders, search box
Expand Down
40 changes: 30 additions & 10 deletions nengo_gui/components/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import nengo
import nengo.spa as spa
try:
import nengo_spa
from nengo_spa.examine import pairs
except ImportError:
nengo_spa = None
import numpy as np

from nengo_gui.components.component import Component
Expand Down Expand Up @@ -31,14 +36,19 @@ def __init__(self, obj, **kwargs):
# neural activity of the population, rather than just changing
# the output.
self.loop_in_whitelist = [spa.Buffer, spa.Memory, spa.State]
if nengo_spa is not None:
self.loop_in_whitelist.extend([nengo_spa.State])

self.node = None
self.conn1 = None
self.conn2 = None

def add_nengo_objects(self, page):
with page.model:
output = self.obj.outputs[self.target][0]
if self.target.startswith('<'):
output = getattr(self.obj, self.target[1:-1])
else:
output = self.obj.outputs[self.target][0]
self.node = nengo.Node(self.gather_data,
size_in=self.vocab_out.dimensions,
size_out=self.vocab_out.dimensions)
Expand All @@ -59,15 +69,25 @@ def gather_data(self, t, x):
vocab = self.vocab_out
key_similarities = np.dot(vocab.vectors, x)
over_threshold = key_similarities > 0.01
matches = zip(key_similarities[over_threshold],
np.array(vocab.keys)[over_threshold])
if self.config.show_pairs:
self.vocab_out.include_pairs = True
pair_similarities = np.dot(vocab.vector_pairs, x)
over_threshold = pair_similarities > 0.01
pair_matches = zip(pair_similarities[over_threshold],
np.array(vocab.key_pairs)[over_threshold])
matches = itertools.chain(matches, pair_matches)
if isinstance(vocab, spa.Vocabulary):
matches = zip(key_similarities[over_threshold],
np.array(vocab.keys)[over_threshold])
if self.config.show_pairs:
self.vocab_out.include_pairs = True
pair_similarities = np.dot(vocab.vector_pairs, x)
over_threshold = pair_similarities > 0.01
pair_matches = zip(pair_similarities[over_threshold],
np.array(vocab.key_pairs)[over_threshold])
matches = itertools.chain(matches, pair_matches)
else:
matches = zip(key_similarities[over_threshold],
[k for i, k in enumerate(vocab) if over_threshold[i]])
if self.config.show_pairs:
pair_similarities = np.array([np.dot(vocab.parse(p).v, x) for p in pairs(vocab)])
over_threshold = pair_similarities > 0.01
pair_matches = zip(pair_similarities[over_threshold],
(k for i, k in enumerate(pairs(vocab)) if over_threshold[i]))
matches = itertools.chain(matches, pair_matches)

text = ';'.join(['%0.2f%s' % ( min(sim, 9.99), key) for sim, key in matches])

Expand Down
33 changes: 27 additions & 6 deletions nengo_gui/components/spa_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import collections

from nengo.spa.module import Module
import nengo.spa as spa
try:
import nengo_spa
except ImportError:
nengo_spa = None

from nengo_gui.components.component import Component

Expand All @@ -13,7 +17,11 @@ def __init__(self, obj, **kwargs):
self.obj = obj
self.data = collections.deque()
self.target = kwargs.get('args', 'default')
self.vocab_out = obj.outputs[self.target][1]
if self.target.startswith('<'):
target_obj = getattr(obj, self.target[1:-1])
self.vocab_out = obj.get_output_vocab(target_obj)
else:
self.vocab_out = obj.outputs[self.target][1]

def attach(self, page, config, uid):
super(SpaPlot, self).attach(page, config, uid)
Expand All @@ -31,8 +39,21 @@ def code_python_args(self, uids):
@staticmethod
def applicable_targets(obj):
targets = []
if isinstance(obj, Module):
for target_name, (obj, vocab) in obj.outputs.items():
if vocab is not None:
targets.append(target_name)
if (isinstance(obj, spa.module.Module) or
(nengo_spa is not None and isinstance(obj, nengo_spa.Network))):

if hasattr(obj, 'outputs'):
for target_name, (obj, vocab) in obj.outputs.items():
if vocab is not None:
targets.append(target_name)
elif hasattr(obj, 'output'):
# TODO: check for other outputs than obj.output
try:
v = obj.get_output_vocab(obj.output)
if v is not None:
targets.append('<output>')
except KeyError:
# Module has no output vocab
pass

return targets
67 changes: 52 additions & 15 deletions nengo_gui/components/spa_similarity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import numpy as np
import nengo
import nengo.spa as spa

try:
from nengo_spa.examine import pairs
except ImportError:
pairs = None

from nengo_gui.components.component import Component
from nengo_gui.components.spa_plot import SpaPlot
Expand All @@ -15,9 +21,13 @@ class SpaSimilarity(SpaPlot):
def __init__(self, obj, **kwargs):
super(SpaSimilarity, self).__init__(obj, **kwargs)

self.old_vocab_length = len(self.vocab_out.keys)
if isinstance(self.vocab_out, spa.Vocabulary):
self.old_vocab_length = len(self.vocab_out.keys)
self.labels = self.vocab_out.keys
else:
self.old_vocab_length = len(self.vocab_out)
self.labels = tuple(self.vocab_out.keys())
self.old_pairs_length = 0
self.labels = self.vocab_out.keys
self.previous_pairs = False

# Nengo objects for data collection
Expand All @@ -26,7 +36,10 @@ def __init__(self, obj, **kwargs):

def add_nengo_objects(self, page):
with page.model:
output = self.obj.outputs[self.target][0]
if self.target.startswith('<'):
output = getattr(self.obj, self.target[1:-1])
else:
output = self.obj.outputs[self.target][0]
self.node = nengo.Node(self.gather_data,
size_in=self.vocab_out.dimensions)
self.conn = nengo.Connection(output, self.node, synapse=0.01)
Expand All @@ -39,7 +52,11 @@ def remove_nengo_objects(self, page):
def gather_data(self, t, x):
vocab = self.vocab_out

if self.old_vocab_length != len(vocab.keys):
if isinstance(vocab, spa.Vocabulary):
length = len(vocab.keys)
else:
length = len(vocab)
if self.old_vocab_length != length:
self.update_legend(vocab)

# get the similarity and send it
Expand All @@ -50,7 +67,10 @@ def gather_data(self, t, x):

# briefly there can be no pairs, so catch the error
try:
pair_similarity = np.dot(vocab.vector_pairs, x)
if isinstance(vocab, spa.Vocabulary):
pair_similarity = np.dot(vocab.vector_pairs, x)
else:
pair_similarity = (np.dot(vocab.parse(p).v, x) for p in pairs(vocab))
simi_list += ['{:.2f}'.format(simi) for simi in pair_similarity]
except TypeError:
pass
Expand All @@ -62,14 +82,19 @@ def gather_data(self, t, x):
def update_legend(self, vocab):
# pass all the missing keys
legend_update = []
legend_update += (vocab.keys[self.old_vocab_length:])
self.old_vocab_length = len(vocab.keys)
if isinstance(vocab, spa.Vocabulary):
legend_update += (vocab.keys[self.old_vocab_length:])
self.old_vocab_length = len(vocab.keys)
else:
legend_update += (list(vocab.keys())[self.old_vocab_length:])
self.old_vocab_length = len(vocab)
# and all the missing pairs if we're showing pairs
if self.config.show_pairs:
# briefly there can be no pairs, so catch the error
try:
legend_update += vocab.key_pairs[self.old_pairs_length:]
self.old_pairs_length = len(vocab.key_pairs)
key_pairs = list(pairs(vocab))
legend_update += key_pairs[self.old_pairs_length:]
self.old_pairs_length = len(key_pairs)
except TypeError:
pass

Expand All @@ -89,12 +114,24 @@ def message(self, msg):
# Send the new labels
if self.config.show_pairs:
vocab.include_pairs = True
self.data.append(
'["reset_legend_and_data", "%s"]' % (
'","'.join(vocab.keys + vocab.key_pairs)))
# if we're starting to show pairs, track pair length
self.old_pairs_length = len(vocab.key_pairs)
if isinstance(vocab, spa.Vocabulary):
self.data.append(
'["reset_legend_and_data", "%s"]' % (
'","'.join(vocab.keys + vocab.key_pairs)))
# if we're starting to show pairs, track pair length
self.old_pairs_length = len(vocab.key_pairs)

else:
self.data.append(
'["reset_legend_and_data", "%s"]' % (
'","'.join(set(vocab.keys()) | pairs(vocab))))
# if we're starting to show pairs, track pair length
self.old_pairs_length = len(pairs(vocab))
else:
vocab.include_pairs = False
self.data.append('["reset_legend_and_data", "%s"]'
if isinstance(vocab, spa.Vocabulary):
self.data.append('["reset_legend_and_data", "%s"]'
% ('","'.join(vocab.keys)))
else:
self.data.append('["reset_legend_and_data", "%s"]'
% ('","'.join(vocab)))