diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index c6bb231ce17..3bb2b88372f 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -693,7 +693,7 @@ def newplotfunc( dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) if dims != darray.dims: darray = darray.transpose(*dims, transpose_coords=True) - elif xval.shape[-1] == darray.shape[0]: + elif darray[xlab].dims[-1] == darray.dims[0]: darray = darray.transpose(transpose_coords=True) # Pass the data as a masked ndarray too diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b824e1ae9b0..34189d1f0b0 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2167,4 +2167,4 @@ def test_plot_transposed_nondim_coord(plotfunc): def test_plot_transposes_properly(): da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x")) hdl = da.plot(x="x", y="y") - np.all(hdl.get_array() == da.to_masked_array().ravel()) + assert np.all(hdl.get_array() == da.to_masked_array().ravel())