Skip to content

Commit

Permalink
FIX: Simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Apr 27, 2019
1 parent 6855ed1 commit e7788e5
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 102 deletions.
19 changes: 11 additions & 8 deletions examples/inverse/plot_lcmv_beamformer_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,23 @@

clim = dict(kind='value', pos_lims=[0.3, 0.6, 0.9])
stc.plot(src=forward['src'], subject='sample', subjects_dir=subjects_dir,
clim=clim, verbose=True)
clim=clim)

###############################################################################
# Now let's:
# - visualize the activity on a "glass brain",
# - show absolute values, and
# - morph data to ``'fsaverage'`` instead of `sample`

morph = mne.compute_source_morph(
forward['src'], 'sample', 'fsaverage', verbose=True)
# XXX The morph is still wrong. And even without morph it's wrong because it's
# not in MNI space.

# morph = mne.compute_source_morph(
# forward['src'], 'sample', 'fsaverage', subjects_dir=subjects_dir,
# verbose='debug')
clim = dict(kind='value', lims=[0.3, 0.6, 0.9])
# XXX there is a bug with the max picking here, it's in a different place
# (time course looks correct, but img does not)...
# Also, clicking is buggy. This is in both glass_brain and stat_map modes.
abs(stc).crop(0.08, 0.12).plot(
morph, subjects_dir=subjects_dir,
mode='glass_brain', clim=clim, verbose=True)
# src=morph,
src=forward['src'],
subject='sample', subjects_dir=subjects_dir,
mode='glass_brain', clim=clim)
189 changes: 96 additions & 93 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from ..surface import (get_meg_helmet_surf, read_surface,
transform_surface_to, _project_onto_surface,
mesh_edges, _reorder_ccw,
mesh_edges, _reorder_ccw, _compute_nearest,
_complete_sphere_surf)
from ..transforms import (read_trans, _find_trans, apply_trans, rot_to_quat,
combine_transforms, _get_trans, _ensure_trans,
Expand All @@ -46,6 +46,7 @@
read_bem_surfaces)


verbose_dec = verbose
FIDUCIAL_ORDER = (FIFF.FIFFV_POINT_LPA, FIFF.FIFFV_POINT_NASION,
FIFF.FIFFV_POINT_RPA)

Expand Down Expand Up @@ -1742,6 +1743,16 @@ def _glass_brain_crosshairs(params, x, y, z):
ax.axhline(b, color='0.75')


def _cut_coords_to_ijk(cut_coords, img):
ijk = apply_trans(linalg.inv(img.affine), cut_coords)
ijk = np.clip(np.round(ijk).astype(int), 0, np.array(img.shape[:3]) - 1)
return ijk


def _ijk_to_cut_coords(ijk, img):
return apply_trans(img.affine, ijk)


@verbose
def plot_volume_source_estimates(stc, src, subject=None, subjects_dir=None,
mode='stat_map', bg_img=None, colorbar=True,
Expand Down Expand Up @@ -1859,62 +1870,60 @@ def plot_volume_source_estimates(stc, src, subject=None, subjects_dir=None,
img = stc.as_volume(src, mri_resolution=False)
del src

def _cut_coords_to_idx(cut_coords, img_shape, strict=True):
def _cut_coords_to_idx(cut_coords, img):
"""Convert voxel coordinates to index in stc.data."""
# XXX: check lines below
cut_coords = apply_trans(linalg.inv(img.affine), cut_coords)
cut_coords = np.array([int(round(c)) for c in cut_coords])

# the affine transformation can sometimes lead to corner
# cases near the edges?
cut_coords = np.clip(cut_coords, 0, np.array(img_shape) - 1)
loc_idx = np.ravel_multi_index(
cut_coords, img_shape, order='F')
dist_vertices = [abs(v - loc_idx) for v in stc.vertices]
nearest_idx = int(round(np.argmin(dist_vertices)))
if dist_vertices[nearest_idx] == 0 or not strict:
return nearest_idx
else:
return None

def _get_cut_coords_stat_map(event, params):
ijk = _cut_coords_to_ijk(cut_coords, img)
del cut_coords
logger.debug(' Affine remapped cut coords to [%d, %d, %d] idx'
% tuple(ijk))
dist_vertices = np.array(
np.unravel_index(stc.vertices, img.shape[:3], order='F')).T
# XXX this assumes zooms are uniform, should probably mult by zooms...
loc_idx, dist = _compute_nearest(dist_vertices, ijk[np.newaxis],
return_dists=True)
loc_idx, dist = loc_idx[0], dist[0]
logger.debug(' Using vertex %d at a distance of %d voxels'
% (stc.vertices[loc_idx], dist))
return loc_idx

ax_name = dict(x='X (saggital)', y='Y (coronal)', z='Z (axial)')

def _click_to_cut_coords(event, params):
"""Get voxel coordinates from mouse click."""
if event.inaxes is params['ax_x']:
cut_coords = (params['ax_z'].lines[0].get_xdata()[0],
event.xdata, event.ydata)
ax = 'x'
x = params['ax_z'].lines[0].get_xdata()[0]
y, z = event.xdata, event.ydata
elif event.inaxes is params['ax_y']:
cut_coords = (event.xdata,
params['ax_x'].lines[0].get_xdata()[0],
event.ydata)
ax = 'y'
y = params['ax_x'].lines[0].get_xdata()[0]
x, z = event.xdata, event.ydata
elif event.inaxes is params['ax_z']:
ax = 'z'
x, y = event.xdata, event.ydata
z = params['ax_x'].lines[1].get_ydata()[0]
else:
cut_coords = (event.xdata, event.ydata,
params['ax_x'].lines[1].get_ydata()[0])
return cut_coords
logger.debug(' Click outside axes')
return None
cut_coords = np.array((x, y, z))

if params['mode'] == 'glass_brain': # find idx for MIP
img_data = np.abs(params['img_idx'].get_data())
ijk = _cut_coords_to_ijk(cut_coords, params['img_idx'])
if ax == 'x':
ijk[0] = np.argmax(img_data[:, ijk[1], ijk[2]])
logger.debug(' MIP: X = %d idx' % (ijk[0],))
elif ax == 'y':
ijk[1] = np.argmax(img_data[ijk[0], :, ijk[2]])
logger.debug(' MIP: Y = %d idx' % (ijk[1],))
else:
ijk[2] = np.argmax(img_data[ijk[0], ijk[1], :])
logger.debug(' MIP: Z = %d idx' % (ijk[2],))
cut_coords = _ijk_to_cut_coords(ijk, params['img_idx'])

def _get_cut_coords_glass_brain(event, params):
"""Get voxel coordinates with max intensity projection."""
img_data = np.abs(params['img_idx_resampled'].get_data())
shape = img_data.shape
if event.inaxes is params['ax_x']:
y, z = int(round(event.xdata)), int(round(event.ydata))
x = np.argmax(img_data[:, y + shape[1] // 2, z + shape[2] // 2])
x -= shape[0] // 2
elif event.inaxes is params['ax_y']:
x, z = int(round(event.xdata)), int(round(event.ydata))
y = np.argmax(img_data[x + shape[0] // 2, :, z + shape[2] // 2])
y -= shape[1] // 2
else:
x, y = int(round(event.xdata)), int(round(event.ydata))
z = np.argmax(img_data[x + shape[0] // 2, y + shape[1] // 2, :])
z -= shape[2] // 2
return (x, y, z)

def _resample(event, params):
"""Precompute the resampling as the mouse leaves the time axis."""
if event.inaxes is params['ax_time'] and mode == 'glass_brain':
img_resampled = resample_to_img(params['img_idx'],
params['bg_img'])
params.update({'img_idx_resampled': img_resampled})
logger.debug(' Cut coords for %s: (%0.1f, %0.1f, %0.1f) mm'
% ((ax_name[ax],) + tuple(cut_coords)))
return cut_coords

def _press(event, params):
"""Manage keypress on the plot."""
Expand All @@ -1934,7 +1943,7 @@ def _press(event, params):
params['fig'].canvas.draw()

def _update_timeslice(idx, params):
cut_coords = (0, 0, 0)
cut_coords = (0, 0, 0) # XXX WHY?
ax_x, ax_y, ax_z = params['ax_x'], params['ax_y'], params['ax_z']
plot_map_callback = params['plot_func']
if mode == 'stat_map':
Expand All @@ -1951,7 +1960,8 @@ def _update_timeslice(idx, params):
params['img_idx'], title='',
cut_coords=cut_coords)

def _onclick(event, params):
@verbose_dec
def _onclick(event, params, verbose=None):
"""Manage clicks on the plot."""
ax_x, ax_y, ax_z = params['ax_x'], params['ax_y'], params['ax_z']
plot_map_callback = params['plot_func']
Expand All @@ -1960,22 +1970,20 @@ def _onclick(event, params):
params['lx'].set_xdata(event.xdata)
_update_timeslice(idx, params)

if event.inaxes in [ax_x, ax_y, ax_z]:
if mode == 'stat_map':
cut_coords = _get_cut_coords_stat_map(event, params)
elif mode == 'glass_brain':
cut_coords = _get_cut_coords_glass_brain(event, params)

ax_x.clear()
ax_y.clear()
ax_z.clear()
plot_map_callback(params['img_idx'], title='',
cut_coords=cut_coords)
loc_idx = _cut_coords_to_idx(cut_coords, params['img_idx'].shape)
if loc_idx is not None:
ax_time.lines[0].set_ydata(stc.data[loc_idx].T)
else:
ax_time.lines[0].set_ydata(0.)
cut_coords = _click_to_cut_coords(event, params)
if cut_coords is None:
return # not in any axes

ax_x.clear()
ax_y.clear()
ax_z.clear()
plot_map_callback(params['img_idx'], title='',
cut_coords=cut_coords)
loc_idx = _cut_coords_to_idx(cut_coords, params['img_idx'])
if loc_idx is not None:
ax_time.lines[0].set_ydata(stc.data[loc_idx].T)
else:
ax_time.lines[0].set_ydata(0.)
params['fig'].canvas.draw()

if bg_img is None:
Expand All @@ -1985,47 +1993,47 @@ def _onclick(event, params):
t1_fname = op.join(subjects_dir, subject, 'mri', 'T1.mgz')
bg_img = nib.load(t1_fname)

bg_img_param = bg_img
if mode == 'glass_brain':
bg_img_param = None

bg_img_param = None if mode == 'glass_brain' else bg_img
if initial_time is None:
time_sl = slice(0, None)
else:
initial_time = float(initial_time)
initial_time = np.argmin(np.abs(stc.times - initial_time))
time_sl = slice(initial_time, initial_time + 1)
if initial_pos is None: # find max pos and (maybe) time
loc_idx, idx = np.unravel_index(np.abs(stc.data[:, time_sl]).argmax(),
stc.data[:, time_sl].shape)
idx += time_sl.start
loc_idx, time_idx = np.unravel_index(
np.abs(stc.data[:, time_sl]).argmax(), stc.data[:, time_sl].shape)
time_idx += time_sl.start
else: # position specified
initial_pos = np.array(initial_pos, float)
if initial_pos.shape != (3,):
raise ValueError('initial_pos must be float ndarray with shape '
'(3,), got shape %s' % (initial_pos.shape,))
loc_idx = _cut_coords_to_idx(1000 * initial_pos,
img.shape[:3], strict=False)
initial_pos *= 1000
loc_idx = _cut_coords_to_idx(initial_pos, img)
if initial_time is not None: # time also specified
idx = time_sl.start
time_idx = time_sl.start
else: # find the max
idx = np.argmax(np.abs(stc.data[loc_idx]))
img_idx = index_img(img, idx)
time_idx = np.argmax(np.abs(stc.data[loc_idx]))
img_idx = index_img(img, time_idx)
assert img_idx.shape == img.shape[:3]
del initial_time, initial_pos
xyz = np.unravel_index(stc.vertices[loc_idx], img_idx.shape, order='F')
cut_coords = apply_trans(img.affine, xyz)
ijk = np.unravel_index(stc.vertices[loc_idx], img_idx.shape, order='F')
cut_coords = _ijk_to_cut_coords(ijk, img_idx)
np.testing.assert_allclose(_cut_coords_to_ijk(cut_coords, img_idx), ijk)
logger.info('Showing t = %0.3f s (%0.1f, %0.1f, %0.1f) mm [%d, %d, %d] idx'
% ((stc.times[idx],) + tuple(cut_coords) + tuple(xyz)))
del xyz
' %d vertex'
% ((stc.times[time_idx],) + tuple(cut_coords) + tuple(ijk) +
(stc.vertices[loc_idx],)))
del ijk

# Plot initial figure
fig, (axes, ax_time) = plt.subplots(2)
ax_time.plot(stc.times, stc.data[loc_idx].T, color='k')
if stc.times[0] != stc.times[-1]:
ax_time.set(xlim=stc.times[[0, -1]])
ax_time.set(xlabel='Time (s)', ylabel='Activation')
lx = ax_time.axvline(stc.times[idx], color='g')
lx = ax_time.axvline(stc.times[time_idx], color='g')
axes.set(xticks=[], yticks=[])
fig.tight_layout()

Expand Down Expand Up @@ -2084,22 +2092,17 @@ def plot_and_correct(*args, **kwargs):
_glass_brain_crosshairs(params, *kwargs['cut_coords'])

params = dict(stc=stc, ax_time=ax_time, plot_func=plot_and_correct,
img_idx=img_idx, fig=fig, bg_img=bg_img, lx=lx)
img_idx=img_idx, fig=fig, lx=lx, mode=mode)

plot_and_correct(stat_map_img=params['img_idx'], title='',
cut_coords=cut_coords)
if mode == 'glass_brain':
params.update(img_idx_resampled=resample_to_img(
params['img_idx'], params['bg_img']))

if show:
plt.show()
fig.canvas.mpl_connect('button_press_event',
partial(_onclick, params=params))
partial(_onclick, params=params, verbose=verbose))
fig.canvas.mpl_connect('key_press_event',
partial(_press, params=params))
fig.canvas.mpl_connect('axes_leave_event',
partial(_resample, params=params))

return fig

Expand Down
3 changes: 2 additions & 1 deletion mne/viz/tests/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,8 @@ def test_plot_volume_source_estimates_morph():
n_time = 2
data = np.random.RandomState(0).rand(n_verts, n_time)
stc = VolSourceEstimate(data, vertices, 1, 1)
morph = compute_source_morph(sample_src, 'sample', 'fsaverage', zooms=10)
morph = compute_source_morph(sample_src, 'sample', 'fsaverage', zooms=10,
subjects_dir=subjects_dir)
initial_pos = (-0.05, -0.01, -0.006)
with pytest.warns(None): # sometimes get scalars/index warning
with catch_logging() as log:
Expand Down

0 comments on commit e7788e5

Please sign in to comment.