diff --git a/doc/gallery/plot_rasterio.py b/doc/gallery/plot_rasterio.py index b42970db970..2ec58b884eb 100644 --- a/doc/gallery/plot_rasterio.py +++ b/doc/gallery/plot_rasterio.py @@ -44,13 +44,10 @@ da.coords['lon'] = (('y', 'x'), lon) da.coords['lat'] = (('y', 'x'), lat) -# Compute a greyscale out of the rgb image -greyscale = da.mean(dim='band') - # Plot on a map ax = plt.subplot(projection=ccrs.PlateCarree()) -greyscale.plot(ax=ax, x='lon', y='lat', transform=ccrs.PlateCarree(), - cmap='Greys_r', add_colorbar=False) +da.plot.imshow(ax=ax, x='lon', y='lat', rgb='band', + transform=ccrs.PlateCarree()) ax.coastlines('10m', color='r') plt.show() diff --git a/doc/whats-new.rst b/doc/whats-new.rst index af53c54dec7..f19908b02b8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,8 @@ Enhancements By `Joe Hamman `_. - Support for using `Zarr`_ as storage layer for xarray. By `Ryan Abernathey `_. +- :func:`xarray.plot.imshow` now handles RGB and RGBA images. + By `Zac Hatfield-Dodds `_. - Experimental support for parsing ENVI metadata to coordinates and attributes in :py:func:`xarray.open_rasterio`. By `Matti Eskelinen `_. diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 8e5ec80d6e6..badd44b25db 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -239,8 +239,9 @@ def map_dataarray(self, func, x, y, **kwargs): func_kwargs.update({'add_colorbar': False, 'add_labels': False}) # Get x, y labels for the first subplot - x, y = _infer_xy_labels(darray=self.data.loc[self.name_dicts.flat[0]], - x=x, y=y) + x, y = _infer_xy_labels( + darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, + imshow=func.__name__ == 'imshow', rgb=kwargs.get('rgb', None)) for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 19c70961c95..2cc39241556 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -443,10 +443,17 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # Decide on a default for the colorbar before facetgrids if add_colorbar is None: add_colorbar = plotfunc.__name__ != 'contour' + imshow_rgb = ( + plotfunc.__name__ == 'imshow' and + darray.ndim == (3 + (row is not None) + (col is not None))) + if imshow_rgb: + # Don't add a colorbar when showing an image with explicit colors + add_colorbar = False # Handle facetgrids first if row or col: allargs = locals().copy() + allargs.pop('imshow_rgb') allargs.update(allargs.pop('kwargs')) # Need the decorated plotting function @@ -470,12 +477,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, "Use colors keyword instead.", DeprecationWarning, stacklevel=3) - xlab, ylab = _infer_xy_labels(darray=darray, x=x, y=y) + rgb = kwargs.pop('rgb', None) + xlab, ylab = _infer_xy_labels( + darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb) + + if rgb is not None and plotfunc.__name__ != 'imshow': + raise ValueError('The "rgb" keyword is only valid for imshow()') + elif rgb is not None and not imshow_rgb: + raise ValueError('The "rgb" keyword is only valid for imshow()' + 'with a three-dimensional array (per facet)') # better to pass the ndarrays directly to plotting functions xval = darray[xlab].values yval = darray[ylab].values - zval = darray.to_masked_array(copy=False) # check if we need to broadcast one dimension if xval.ndim < yval.ndim: @@ -486,8 +500,19 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names - if darray[xlab].dims[-1] == darray.dims[0]: - zval = zval.T + if imshow_rgb: + # For RGB[A] images, matplotlib requires the color dimension + # to be last. In Xarray the order should be unimportant, so + # we transpose to (y, x, color) to make this work. + yx_dims = (ylab, xlab) + dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) + if dims != darray.dims: + darray = darray.transpose(*dims) + elif darray[xlab].dims[-1] == darray.dims[0]: + darray = darray.transpose() + + # Pass the data as a masked ndarray too + zval = darray.to_masked_array(copy=False) _ensure_plottable(xval, yval) @@ -595,6 +620,11 @@ def imshow(x, y, z, ax, **kwargs): Wraps :func:`matplotlib:matplotlib.pyplot.imshow` + While other plot methods require the DataArray to be strictly + two-dimensional, ``imshow`` also accepts a 3D array where some + dimension can be interpreted as RGB or RGBA color channels and + allows this dimension to be specified via the kwarg ``rgb=``. + .. note:: This function needs uniformly spaced coordinates to properly label the axes. Call DataArray.plot() to check. @@ -632,6 +662,15 @@ def imshow(x, y, z, ax, **kwargs): # Allow user to override these defaults defaults.update(kwargs) + if z.ndim == 3: + # matplotlib imshow uses black for missing data, but Xarray makes + # missing data transparent. We therefore add an alpha channel if + # there isn't one, and set it to transparent where data is masked. + if z.shape[-1] == 3: + z = np.ma.concatenate((z, np.ma.ones(z.shape[:2] + (1,))), 2) + z = z.copy() + z[np.any(z.mask, axis=-1), -1] = 0 + primitive = ax.imshow(z, **defaults) return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 3d2a633c7dc..abd62df2296 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -258,12 +258,65 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, levels=levels, norm=norm) -def _infer_xy_labels(darray, x, y): +def _infer_xy_labels_3d(darray, x, y, rgb): + """ + Determine x and y labels for showing RGB images. + + Attempts to infer which dimension is RGB/RGBA by size and order of dims. + + """ + assert rgb is None or rgb != x + assert rgb is None or rgb != y + # Start by detecting and reporting invalid combinations of arguments + assert darray.ndim == 3 + not_none = [a for a in (x, y, rgb) if a is not None] + if len(set(not_none)) < len(not_none): + raise ValueError( + 'Dimension names must be None or unique strings, but imshow was ' + 'passed x=%r, y=%r, and rgb=%r.' % (x, y, rgb)) + for label in not_none: + if label not in darray.dims: + raise ValueError('%r is not a dimension' % (label,)) + + # Then calculate rgb dimension if certain and check validity + could_be_color = [label for label in darray.dims + if darray[label].size in (3, 4) and label not in (x, y)] + if rgb is None and not could_be_color: + raise ValueError( + 'A 3-dimensional array was passed to imshow(), but there is no ' + 'dimension that could be color. At least one dimension must be ' + 'of size 3 (RGB) or 4 (RGBA), and not given as x or y.') + if rgb is None and len(could_be_color) == 1: + rgb = could_be_color[0] + if rgb is not None and darray[rgb].size not in (3, 4): + raise ValueError('Cannot interpret dim %r of size %s as RGB or RGBA.' + % (rgb, darray[rgb].size)) + + # If rgb dimension is still unknown, there must be two or three dimensions + # in could_be_color. We therefore warn, and use a heuristic to break ties. + if rgb is None: + assert len(could_be_color) in (2, 3) + rgb = could_be_color[-1] + warnings.warn( + 'Several dimensions of this array could be colors. Xarray ' + 'will use the last possible dimension (%r) to match ' + 'matplotlib.pyplot.imshow. You can pass names of x, y, ' + 'and/or rgb dimensions to override this guess.' % rgb) + assert rgb is not None + + # Finally, we pick out the red slice and delegate to the 2D version: + return _infer_xy_labels(darray.isel(**{rgb: 0}).squeeze(), x, y) + + +def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): """ Determine x and y labels. For use in _plot2d - darray must be a 2 dimensional data array. + darray must be a 2 dimensional data array, or 3d for imshow only. """ + assert x is None or x != y + if imshow and darray.ndim == 3: + return _infer_xy_labels_3d(darray, x, y, rgb) if x is None and y is None: if darray.ndim != 2: diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 50a5c47c6bd..4615c59884a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -619,6 +619,8 @@ def test_1d_raises_valueerror(self): def test_3d_raises_valueerror(self): a = DataArray(easy_array((2, 3, 4))) + if self.plotfunc.__name__ == 'imshow': + pytest.skip() with raises_regex(ValueError, r'DataArray must be 2d'): self.plotfunc(a) @@ -670,6 +672,11 @@ def test_can_plot_axis_size_one(self): if self.plotfunc.__name__ not in ('contour', 'contourf'): self.plotfunc(DataArray(np.ones((1, 1)))) + def test_disallows_rgb_arg(self): + with pytest.raises(ValueError): + # Always invalid for most plots. Invalid for imshow with 2D data. + self.plotfunc(DataArray(np.ones((2, 2))), rgb='not None') + def test_viridis_cmap(self): cmap_name = self.plotmethod(cmap='viridis').get_cmap().name self.assertEqual('viridis', cmap_name) @@ -1062,6 +1069,52 @@ def test_2d_coord_names(self): with raises_regex(ValueError, 'requires 1D coordinates'): self.plotmethod(x='x2d', y='y2d') + def test_plot_rgb_image(self): + DataArray( + easy_array((10, 15, 3), start=0), + dims=['y', 'x', 'band'], + ).plot.imshow() + self.assertEqual(0, len(find_possible_colorbars())) + + def test_plot_rgb_image_explicit(self): + DataArray( + easy_array((10, 15, 3), start=0), + dims=['y', 'x', 'band'], + ).plot.imshow(y='y', x='x', rgb='band') + self.assertEqual(0, len(find_possible_colorbars())) + + def test_plot_rgb_faceted(self): + DataArray( + easy_array((2, 2, 10, 15, 3), start=0), + dims=['a', 'b', 'y', 'x', 'band'], + ).plot.imshow(row='a', col='b') + self.assertEqual(0, len(find_possible_colorbars())) + + def test_plot_rgba_image_transposed(self): + # We can handle the color axis being in any position + DataArray( + easy_array((4, 10, 15), start=0), + dims=['band', 'y', 'x'], + ).plot.imshow() + + def test_warns_ambigious_dim(self): + arr = DataArray(easy_array((3, 3, 3)), dims=['y', 'x', 'band']) + with pytest.warns(UserWarning): + arr.plot.imshow() + # but doesn't warn if dimensions specified + arr.plot.imshow(rgb='band') + arr.plot.imshow(x='x', y='y') + + def test_rgb_errors_too_many_dims(self): + arr = DataArray(easy_array((3, 3, 3, 3)), dims=['y', 'x', 'z', 'band']) + with pytest.raises(ValueError): + arr.plot.imshow(rgb='band') + + def test_rgb_errors_bad_dim_sizes(self): + arr = DataArray(easy_array((5, 5, 5)), dims=['y', 'x', 'band']) + with pytest.raises(ValueError): + arr.plot.imshow(rgb='band') + class TestFacetGrid(PlotTestCase):