diff --git a/doc/plotting.rst b/doc/plotting.rst index cd081811b99..2b816a24563 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius: The Celsius data contain 0, so a diverging color map was used. The Kelvins do not have 0, so the default color map was used. +.. _robust-plotting: + Robust ~~~~~~ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 7bfe5991b78..12d3b910ca6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,7 @@ Enhancements - Support for using `Zarr`_ as storage layer for xarray. By `Ryan Abernathey `_. - :func:`xarray.plot.imshow` now handles RGB and RGBA images. + Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``. By `Zac Hatfield-Dodds `_. - Experimental support for parsing ENVI metadata to coordinates and attributes in :py:func:`xarray.open_rasterio`. diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 2cc39241556..0423c0967d1 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -15,8 +15,8 @@ import pandas as pd from datetime import datetime -from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis, - import_matplotlib_pyplot) +from .utils import (ROBUST_PERCENTILE, _determine_cmap_params, + _infer_xy_labels, get_axis, import_matplotlib_pyplot) from .facetgrid import FacetGrid from xarray.core.pycompat import basestring @@ -449,10 +449,28 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, if imshow_rgb: # Don't add a colorbar when showing an image with explicit colors add_colorbar = False + # Calculate vmin and vmax automatically for `robust=True` + if robust: + if vmin is None: + vmin = np.nanpercentile(darray, ROBUST_PERCENTILE) + if vmax is None: + vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE) + robust = False + # Scale interval [vmin .. vmax] to [0 .. 1] and clip to bounds + if vmin is not None or vmax is not None: + vmin = vmin if vmin is not None else darray.min() + vmax = vmax if vmax is not None else darray.max() + darray = ((darray.astype('f8') - vmin) / (vmax - vmin)) + vmin, vmax = None, None + # There's a cyclic dependency via DataArray, so we can't + # import xarray.ufuncs in global or outer scope. + import xarray.ufuncs as xu + darray = xu.minimum(xu.maximum(darray.astype('f4'), 0), 1) # Handle facetgrids first if row or col: allargs = locals().copy() + allargs.pop('xu', None) allargs.pop('imshow_rgb') allargs.update(allargs.pop('kwargs')) @@ -625,6 +643,11 @@ def imshow(x, y, z, ax, **kwargs): dimension can be interpreted as RGB or RGBA color channels and allows this dimension to be specified via the kwarg ``rgb=``. + Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA + data, by applying a single scaling factor and offset to all bands. + Passing ``robust=True`` infers ``vmin`` and ``vmax`` + :ref:`in the usual way `. + .. note:: This function needs uniformly spaced coordinates to properly label the axes. Call DataArray.plot() to check. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index abd62df2296..2c9fbb91aaa 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -11,6 +11,9 @@ from ..core.utils import is_scalar +ROBUST_PERCENTILE = 2.0 + + def _load_default_cmap(fname='default_colormap.csv'): """ Returns viridis color map @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, cmap_params : dict Use depends on the type of the plotting function """ - ROBUST_PERCENTILE = 2.0 import matplotlib as mpl calc_data = np.ravel(plot_data[~pd.isnull(plot_data)]) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4615c59884a..55fba587e27 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1115,6 +1115,16 @@ def test_rgb_errors_bad_dim_sizes(self): with pytest.raises(ValueError): arr.plot.imshow(rgb='band') + def test_normalize_rgb_imshow(self): + for kwds in ( + dict(vmin=-1), dict(vmax=2), + dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0), + dict(vmin=0, robust=True), dict(vmax=-1, robust=True), + ): + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + arr = da.plot.imshow(**kwds).get_array() + assert 0 <= arr.min() <= arr.max() <= 1, kwds + class TestFacetGrid(PlotTestCase):