Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize dask array equality checks. #3453

Merged
merged 11 commits into from
Nov 5, 2019
16 changes: 16 additions & 0 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import dtypes, utils
from .alignment import align
from .duck_array_ops import lazy_array_equiv
from .merge import _VALID_COMPAT, unique_variable
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars
Expand Down Expand Up @@ -189,6 +190,21 @@ def process_subset_opt(opt, subset):
# all nonindexes that are not the same in each dataset
for k in getattr(datasets[0], subset):
if k not in concat_over:
equals[k] = None
variables = [ds.variables[k] for ds in datasets]
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals[k] = getattr(variables[0], compat)(
var, equiv=lazy_array_equiv
)
if not equals[k]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same concern as below -- should this be checking against None explicitly, which I believe lazy_array_equiv can return?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is right. We are exiting early because equals[k] is not True i.e. either the shapes are not equal, or one (or both) of the arrays is a numpy array or the dask names are not equal

break

if equals[k] is not None:
if equals[k] is False:
concat_over.add(k)
continue

# Compare the variable of all datasets vs. the one
# of the first dataset. Perform the minimum amount of
# loads in order to avoid multiple loads from disk
Expand Down
41 changes: 41 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,49 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())


def lazy_array_equiv(arr1, arr2):
"""Like array_equal, but doesn't actually compare values
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True

dcherian marked this conversation as resolved.
Show resolved Hide resolved

def array_equiv(arr1, arr2):
"""Like np.array_equal, but also allows values to be NaN in both arrays
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking a little repetitive. Can you make a helper function for doing these checks?

Also note the similar logic (checking object identity) inside Variable.equals

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. We can just call lazy_array_equiv in all these before doing the actual value comparison. Then only the asarray cast is duplicated.

Re Variable.equals:

xarray/xarray/core/variable.py

Lines 1560 to 1576 in fb0cf7b

def equals(self, other, equiv=duck_array_ops.array_equiv):
"""True if two Variables have the same dimensions and values;
otherwise False.
Variables can still be equal (like pandas objects) if they have NaN
values in the same locations.
This method is necessary because `v1 == v2` for Variables
does element-wise comparisons (like numpy.ndarrays).
"""
other = getattr(other, "variable", other)
try:
return self.dims == other.dims and (
self._data is other._data or equiv(self.data, other.data)
)
except (TypeError, AttributeError):
return False

equiv=duck_array_ops.array_equiv by default which now calls duck_array_ops.lazy_array_equiv first so things are good?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I've added the identity check to lazy_array_equiv too

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
Expand All @@ -205,6 +238,14 @@ def array_notnull_equiv(arr1, arr2):
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
if (
dask_array
and isinstance(arr1, dask_array.Array)
and isinstance(arr2, dask_array.Array)
):
# GH3068
if arr1.name == arr2.name:
return True
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
Expand Down
17 changes: 13 additions & 4 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import dtypes, pdcompat
from .alignment import deep_align
from .duck_array_ops import lazy_array_equiv
from .utils import Frozen, dict_equiv
from .variable import Variable, as_variable, assert_unique_multiindex_level_names

Expand Down Expand Up @@ -123,16 +124,24 @@ def unique_variable(
combine_method = "fillna"

if equals is None:
out = out.compute()
# first check without comparing values i.e. no computes
for var in variables[1:]:
equals = getattr(out, compat)(var)
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
if not equals:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be if equals is not None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is right (as above) but the next one is wrong. I've fixed that and added a test.

break

# now compare values with minimum number of computes
if not equals:
out = out.compute()
for var in variables[1:]:
equals = getattr(out, compat)(var)
if not equals:
break

if not equals:
raise MergeError(
"conflicting values for variable {!r} on objects to be combined. "
"You can skip this check by specifying compat='override'.".format(name)
f"conflicting values for variable {name!r} on objects to be combined. "
"You can skip this check by specifying compat='override'."
)

if combine_method:
Expand Down
14 changes: 9 additions & 5 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
if len(dims) == 0:
dims = self.dims[::-1]
axes = self.get_axis_num(dims)
if len(dims) < 2: # no need to transpose if only one dimension
if len(dims) < 2 or dims == self.dims:
# no need to transpose if only one dimension
# or dims are in same order
return self.copy(deep=False)

data = as_indexable(self._data).transpose(axes)
Expand Down Expand Up @@ -1588,22 +1590,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
return False
return self.equals(other, equiv=equiv)

def identical(self, other):
def identical(self, other, equiv=duck_array_ops.array_equiv):
"""Like equals, but also checks attributes.
"""
try:
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other)
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
other, equiv=equiv
)
except (TypeError, AttributeError):
return False

def no_conflicts(self, other):
def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
"""True if the intersection of two Variable's non-null data is
equal; otherwise false.

Variables can thus still be equal if there are locations where either,
or both, contain NaN values.
"""
return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv)
return self.broadcast_equals(other, equiv=equiv)

def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
"""Compute the qth quantile of the data along the specified dimension.
Expand Down
40 changes: 40 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
assert_identical,
raises_regex,
)
from ..core.duck_array_ops import lazy_array_equiv

dask = pytest.importorskip("dask")
da = pytest.importorskip("dask.array")
Expand Down Expand Up @@ -1135,3 +1136,42 @@ def test_make_meta(map_ds):
for variable in map_ds.data_vars:
assert variable in meta.data_vars
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim


def test_identical_coords_no_computes():
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
a = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
b = xr.DataArray(
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
)
with raise_if_dask_computes():
c = a + b
assert_identical(c, a)


def test_lazy_array_equiv():
lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
var1 = lons1.variable
var2 = lons2.variable
with raise_if_dask_computes():
lons1.equals(lons2)
with raise_if_dask_computes():
var1.equals(var2 / 2, equiv=lazy_array_equiv)
assert var1.equals(var2.compute(), equiv=lazy_array_equiv) is None
assert var1.compute().equals(var2.compute(), equiv=lazy_array_equiv) is None

with raise_if_dask_computes():
assert lons1.equals(lons1.transpose("y", "x"))

with raise_if_dask_computes():
for compat in [
"broadcast_equals",
"equals",
"override",
"identical",
"no_conflicts",
]:
xr.merge([lons1, lons2], compat=compat)