Skip to content

Commit

Permalink
Normalisation for RGB imshow
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Jan 14, 2018
1 parent 502a988 commit 83f567c
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius:
The Celsius data contain 0, so a diverging color map was used. The
Kelvins do not have 0, so the default color map was used.

.. _robust-plotting:

Robust
~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
Expand Down
27 changes: 25 additions & 2 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import pandas as pd
from datetime import datetime

from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis,
import_matplotlib_pyplot)
from .utils import (ROBUST_PERCENTILE, _determine_cmap_params,
_infer_xy_labels, get_axis, import_matplotlib_pyplot)
from .facetgrid import FacetGrid
from xarray.core.pycompat import basestring

Expand Down Expand Up @@ -449,10 +449,28 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False
# Calculate vmin and vmax automatically for `robust=True`
if robust:
if vmin is None:
vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
if vmax is None:
vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
robust = False
# Scale interval [vmin .. vmax] to [0 .. 1] and clip to bounds
if vmin is not None or vmax is not None:
vmin = vmin if vmin is not None else darray.min()
vmax = vmax if vmax is not None else darray.max()
darray = ((darray.astype('f8') - vmin) / (vmax - vmin))
vmin, vmax = None, None
# There's a cyclic dependency via DataArray, so we can't
# import xarray.ufuncs in global or outer scope.
import xarray.ufuncs as xu
darray = xu.minimum(xu.maximum(darray.astype('f4'), 0), 1)

# Handle facetgrids first
if row or col:
allargs = locals().copy()
allargs.pop('xu', None)
allargs.pop('imshow_rgb')
allargs.update(allargs.pop('kwargs'))

Expand Down Expand Up @@ -625,6 +643,11 @@ def imshow(x, y, z, ax, **kwargs):
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.
Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA
data, by applying a single scaling factor and offset to all bands.
Passing ``robust=True`` infers ``vmin`` and ``vmax``
:ref:`in the usual way <robust-plotting>`.
.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down
4 changes: 3 additions & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ..core.utils import is_scalar


ROBUST_PERCENTILE = 2.0


def _load_default_cmap(fname='default_colormap.csv'):
"""
Returns viridis color map
Expand Down Expand Up @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap_params : dict
Use depends on the type of the plotting function
"""
ROBUST_PERCENTILE = 2.0
import matplotlib as mpl

calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,16 @@ def test_rgb_errors_bad_dim_sizes(self):
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_normalize_rgb_imshow(self):
for kwds in (
dict(vmin=-1), dict(vmax=2),
dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0),
dict(vmin=0, robust=True), dict(vmax=-1, robust=True),
):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
arr = da.plot.imshow(**kwds).get_array()
assert 0 <= arr.min() <= arr.max() <= 1, kwds


class TestFacetGrid(PlotTestCase):

Expand Down

0 comments on commit 83f567c

Please sign in to comment.