Skip to content

Commit

Permalink
Validation, tests for rgb imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Jan 9, 2018
1 parent 0226d4c commit 9ddc03b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
9 changes: 8 additions & 1 deletion xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
# Convert byte-arrays to float for correct display in matplotlib
if darray.dtype == np.dtype('uint8'):
darray = darray / 256.0
# Manually stretch colors for robust cmap
# Manually stretch colors for robust cmap. We have to do this
# first so faceted plots are comparable between facets.
if robust:
flat = darray.values.ravel(order='K')
flat = flat[~np.isnan(flat)]
Expand Down Expand Up @@ -503,6 +504,12 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
xlab, ylab = _infer_xy_labels(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb)

if rgb is not None and plotfunc.__name__ != 'imshow':
raise ValueError('The "rgb" keyword is only valid for imshow()')
elif rgb is not None and not imshow_rgb:
raise ValueError('The "rgb" keyword is only valid for imshow()'
'with a three-dimensional array (per facet)')

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,11 @@ def test_can_plot_axis_size_one(self):
if self.plotfunc.__name__ not in ('contour', 'contourf'):
self.plotfunc(DataArray(np.ones((1, 1))))

def test_disallows_rgb_arg(self):
with pytest.raises(ValueError):
# Always invalid for most plots. Invalid for imshow with 2D data.
self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None')

def test_viridis_cmap(self):
cmap_name = self.plotmethod(cmap='viridis').get_cmap().name
self.assertEqual('viridis', cmap_name)
Expand Down Expand Up @@ -1071,6 +1076,13 @@ def test_plot_rgb_image(self):
).plot.imshow()
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_image_explicit(self):
DataArray(
easy_array((10, 15, 3), start=0),
dims=['y', 'x', 'band'],
).plot.imshow(y='y', x='x', rgb='band')
self.assertEqual(0, len(find_possible_colorbars()))

def test_plot_rgb_faceted(self):
DataArray(
easy_array((2, 2, 10, 15, 3), start=0),
Expand All @@ -1093,6 +1105,16 @@ def test_warns_ambigious_dim(self):
arr.plot.imshow(rgb='band')
arr.plot.imshow(x='x', y='y')

def test_rgb_errors_too_many_dims(self):
arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_rgb_errors_bad_dim_sizes(self):
arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band'])
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')


class TestFacetGrid(PlotTestCase):

Expand Down

0 comments on commit 9ddc03b

Please sign in to comment.