Skip to content

Commit

Permalink
Allow RGB[A] dim for imshow to be in any order
Browse files Browse the repository at this point in the history
Includes new `rgb` keyword to tell imshow about that dimension, and much
error handling in inference.
  • Loading branch information
Zac-HD committed Dec 21, 2017
1 parent ad00933 commit 540300d
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 16 deletions.
5 changes: 3 additions & 2 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs):
func_kwargs.update({'add_colorbar': False, 'add_labels': False})

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, imshow=func.__name__ == 'imshow')
x, y = _infer_xy_labels(
darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y,
imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None))

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
Expand Down
26 changes: 19 additions & 7 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,13 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
"Use colors keyword instead.",
DeprecationWarning, stacklevel=3)

xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y,
imshow=plotfunc.__name__ == 'imshow')
rgb = kwargs.pop('rgb', None)
xlab, ylab = _infer_xy_labels(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values
zval = darray.to_masked_array(copy=False)

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
Expand All @@ -485,8 +485,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if darray[xlab].dims[-1] == darray.dims[0]:
zval = zval.T
if imshow_rgb:
# For RGB[A] images, matplotlib requires the color dimension
# to be last. In Xarray the order should be unimportant, so
# we transpose to (y, x, color) to make this work.
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose()

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)

_ensure_plottable(xval, yval)

Expand Down Expand Up @@ -591,8 +602,9 @@ def imshow(x, y, z, ax, **kwargs):
Wraps :func:`matplotlib:matplotlib.pyplot.imshow`
While other plot methods require the DataArray to be strictly
two-dimensional, ``imshow`` also accepts a 3D array where the third
dimension can be interpreted as RGB or RGBA color channels.
two-dimensional, ``imshow`` also accepts a 3D array where some
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.
In this case, ``robust=True`` will saturate the image in the
usual way, consistenly between all bands and facets.
Expand Down
65 changes: 58 additions & 7 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,72 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
levels=levels, norm=norm)


def _infer_xy_labels(darray, x, y, imshow=False):
def _infer_xy_labels_3d(darray, x, y, rgb):
"""
Determine x and y labels for showing RGB images.
Attempts to infer which dimension is RGB/RGBA by size and order of dims.
"""
# Start by detecting and reporting invalid combinations of arguments
assert darray.ndim == 3
not_none = [a for a in (x, y, rgb) if a is not None]
if len(set(not_none)) < len(not_none):
raise ValueError('Dimensions passed as x, y, and rgb must be unique.')
for label in not_none:
if label not in darray.dims:
raise ValueError('%r is not a dimension' % (label,))

# Then calculate rgb dimension if certain and check validity
could_be_color = [label for label in darray.dims
if darray[label].size in (3, 4) and label not in (x, y)]
if rgb is None and not could_be_color:
raise ValueError(
'A 3-dimensional array was passed to imshow(), but there is no '
'dimension that could be color. At least one dimension must be '
'of size 3 (RGB) or 4 (RGBA), and not given as x or y.')
if rgb is None and len(could_be_color) == 1:
rgb = could_be_color[0]
if rgb is not None and darray[rgb].size not in (3, 4):
raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.'
% (rgb, darray[rgb].size))

# If rgb dimension is still unknown, there must be two or three dimensions
# in could_be_color. We therefore warn, and use a heuristic to break ties.
if rgb is None:
assert len(could_be_color) in (2, 3)
if darray.dims[-1] in could_be_color:
rgb = darray.dims[-1]
warnings.warn(
'Several dimensions of this array could be colors. Xarray '
'will use the last dimension (%r) to match '
'matplotlib.pyplot.imshow. You can pass names of x, y, '
'and/or rgb dimensions to override this guess.' % rgb)
else:
rgb = darray.dims[0]
warnings.warn(
'%r has been selected as the color dimension, but %r would '
'also be valid. Pass names of x, y, and/or rgb dimensions to '
'override this guess.' % darray.dims[:2])
assert rgb is not None

# Finally, we pick out the red slice and delegate to the 2D version:
return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y)


def _infer_xy_labels(darray, x, y, imshow=False, rgb=None):
"""
Determine x and y labels. For use in _plot2d
darray must be a 2 dimensional data array, or 3d for imshow only.
"""
if imshow and darray.ndim == 3:
return _infer_xy_labels_3d(darray, x, y, rgb)

if x is None and y is None:
if darray.ndim != 2:
if not imshow:
raise ValueError('DataArray must be 2d')
elif darray.ndim != 3 or darray.shape[2] not in (3, 4):
raise ValueError('DataArray for imshow must be 2d, MxNx3 for '
'RGB image, or MxNx4 for RGBA image.')
y, x, *_ = darray.dims
raise ValueError('DataArray must be 2d')
y, x = darray.dims
elif x is None:
if y not in darray.dims:
raise ValueError('y must be a dimension name if x is not supplied')
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,21 @@ def test_plot_rgb_faceted(self):
).plot.imshow(row='a', col='b')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgba_image_transposed(self):
# We can handle the color axis being in any position
DataArray(
easy_array((4, 10, 15), start=0),
dims=['band', 'y', 'x'],
).plot.imshow()

def test_warns_ambigious_dim(self):
arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band'])
with pytest.warns(UserWarning):
arr.plot.imshow()
# but doesn't warn if dimensions specified
arr.plot.imshow(rgb='band')
arr.plot.imshow(x='x', y='y')


class TestFacetGrid(PlotTestCase):

Expand Down

0 comments on commit 540300d

Please sign in to comment.