From 7aed950777bf28eb7ad37a96b46f36c602eac5ec Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 2 Nov 2019 15:20:07 -0600 Subject: [PATCH] make plotting work with transposed nondim coords. --- doc/whats-new.rst | 3 +++ xarray/plot/plot.py | 14 ++++++++++---- xarray/tests/test_plot.py | 16 ++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 47e2e58e988..30b167ebfdc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,9 @@ Bug fixes :py:meth:`xarray.core.groupby.DatasetGroupBy.reduce` when reducing over multiple dimensions. (:issue:`3402`). By `Deepak Cherian `_ +- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`) + By `Deepak Cherian `_. + Documentation ~~~~~~~~~~~~~ - Fix leap year condition in example (http://xarray.pydata.org/en/stable/examples/monthly-means.html) by `Mickaƫl Lalande `_. diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index ca68f617144..c6bb231ce17 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -672,10 +672,16 @@ def newplotfunc( # check if we need to broadcast one dimension if xval.ndim < yval.ndim: - xval = np.broadcast_to(xval, yval.shape) + if xval.shape[0] == yval.shape[0]: + xval = np.broadcast_to(xval[:, np.newaxis], yval.shape) + else: + xval = np.broadcast_to(xval[np.newaxis, :], yval.shape) - if yval.ndim < xval.ndim: - yval = np.broadcast_to(yval, xval.shape) + elif yval.ndim < xval.ndim: + if yval.shape[0] == xval.shape[0]: + yval = np.broadcast_to(yval[:, np.newaxis], xval.shape) + else: + yval = np.broadcast_to(yval[np.newaxis, :], xval.shape) # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names @@ -687,7 +693,7 @@ def newplotfunc( dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) if dims != darray.dims: darray = darray.transpose(*dims, transpose_coords=True) - elif darray[xlab].dims[-1] == darray.dims[0]: + elif xval.shape[-1] == darray.shape[0]: darray = darray.transpose(transpose_coords=True) # Pass the data as a masked ndarray too diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 7deabd46eae..58210e6ffa0 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2145,3 +2145,19 @@ def test_yticks_kwarg(self, da): da.plot(yticks=np.arange(5)) expected = np.arange(5) assert np.all(plt.gca().get_yticks() == expected) + + +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) +def test_plot_transposed_nondim_coord(plotfunc): + x = np.linspace(0, 10, 101) + h = np.linspace(3, 7, 101) + s = np.linspace(0, 1, 51) + z = s[:, np.newaxis] * h[np.newaxis, :] + da = xr.DataArray( + np.sin(x) * np.cos(z), + dims=["s", "x"], + coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)}, + ) + getattr(da.plot, plotfunc)(x="x", y="zt") + getattr(da.plot, plotfunc)(x="zt", y="x")