Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalisation for RGB imshow #1819

Merged
merged 2 commits into from
Jan 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ example, consider the original data in Kelvins rather than Celsius:
The Celsius data contain 0, so a diverging color map was used. The
Kelvins do not have 0, so the default color map was used.

.. _robust-plotting:

Robust
~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Enhancements
- Support for using `Zarr`_ as storage layer for xarray.
By `Ryan Abernathey <https://github.com/rabernat>`_.
- :func:`xarray.plot.imshow` now handles RGB and RGBA images.
Saturation can be adjusted with ``vmin`` and ``vmax``, or with ``robust=True``.
By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- Experimental support for parsing ENVI metadata to coordinates and attributes
in :py:func:`xarray.open_rasterio`.
Expand Down
47 changes: 45 additions & 2 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import pandas as pd
from datetime import datetime

from .utils import (_determine_cmap_params, _infer_xy_labels, get_axis,
import_matplotlib_pyplot)
from .utils import (ROBUST_PERCENTILE, _determine_cmap_params,
_infer_xy_labels, get_axis, import_matplotlib_pyplot)
from .facetgrid import FacetGrid
from xarray.core.pycompat import basestring

Expand Down Expand Up @@ -326,6 +326,39 @@ def line(self, *args, **kwargs):
return line(self._da, *args, **kwargs)


def _rescale_imshow_rgb(darray, vmin, vmax, robust):
assert robust or vmin is not None or vmax is not None
# There's a cyclic dependency via DataArray, so we can't import from
# xarray.ufuncs in global scope.
from xarray.ufuncs import maximum, minimum
# Calculate vmin and vmax automatically for `robust=True`
if robust:
if vmax is None:
vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
if vmin is None:
vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
# If not robust and one bound is None, calculate the default other bound
# and check that an interval between them exists.
elif vmax is None:
vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1
if vmax < vmin:
raise ValueError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These error checks look great, thanks! Can you add a test that covers them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 😄

'vmin=%r is less than the default vmax (%r) - you must supply '
'a vmax > vmin in this case.' % (vmin, vmax))
elif vmin is None:
vmin = 0
if vmin > vmax:
raise ValueError(
'vmax=%r is less than the default vmin (0) - you must supply '
'a vmin < vmax in this case.' % vmax)
# Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float
# to avoid precision loss, integer over/underflow, etc with extreme inputs.
# After scaling, downcast to 32-bit float. This substantially reduces
# memory usage after we hand `darray` off to matplotlib.
darray = ((darray.astype('f8') - vmin) / (vmax - vmin)).astype('f4')
return minimum(maximum(darray, 0), 1)


def _plot2d(plotfunc):
"""
Decorator for common 2d plotting logic
Expand Down Expand Up @@ -449,6 +482,11 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if imshow_rgb:
# Don't add a colorbar when showing an image with explicit colors
add_colorbar = False
# Matplotlib does not support normalising RGB data, so do it here.
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
if robust or vmax is not None or vmin is not None:
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
vmin, vmax, robust = None, None, False

# Handle facetgrids first
if row or col:
Expand Down Expand Up @@ -625,6 +663,11 @@ def imshow(x, y, z, ax, **kwargs):
dimension can be interpreted as RGB or RGBA color channels and
allows this dimension to be specified via the kwarg ``rgb=``.

Unlike matplotlib, Xarray can apply ``vmin`` and ``vmax`` to RGB or RGBA
data, by applying a single scaling factor and offset to all bands.
Passing ``robust=True`` infers ``vmin`` and ``vmax``
:ref:`in the usual way <robust-plotting>`.

.. note::
This function needs uniformly spaced coordinates to
properly label the axes. Call DataArray.plot() to check.
Expand Down
4 changes: 3 additions & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from ..core.utils import is_scalar


ROBUST_PERCENTILE = 2.0


def _load_default_cmap(fname='default_colormap.csv'):
"""
Returns viridis color map
Expand Down Expand Up @@ -165,7 +168,6 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap_params : dict
Use depends on the type of the plotting function
"""
ROBUST_PERCENTILE = 2.0
import matplotlib as mpl

calc_data = np.ravel(plot_data[~pd.isnull(plot_data)])
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,26 @@ def test_rgb_errors_bad_dim_sizes(self):
with pytest.raises(ValueError):
arr.plot.imshow(rgb='band')

def test_normalize_rgb_imshow(self):
for kwds in (
dict(vmin=-1), dict(vmax=2),
dict(vmin=-1, vmax=1), dict(vmin=0, vmax=0),
dict(vmin=0, robust=True), dict(vmax=-1, robust=True),
):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
arr = da.plot.imshow(**kwds).get_array()
assert 0 <= arr.min() <= arr.max() <= 1, kwds

def test_normalize_rgb_one_arg_error(self):
da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4))
# If passed one bound that implies all out of range, error:
for kwds in [dict(vmax=-1), dict(vmin=2)]:
with pytest.raises(ValueError):
da.plot.imshow(**kwds)
# If passed two that's just moving the range, *not* an error:
for kwds in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]:
da.plot.imshow(**kwds)


class TestFacetGrid(PlotTestCase):
def setUp(self):
Expand Down