Skip to content

Commit

Permalink
Support for 'no_conflicts', 'broadcast_equals' in concat.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Aug 22, 2019
1 parent 820463b commit 1b4858f
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from . import utils, dtypes
from .alignment import align
from .merge import broadcast_dimension_size
from .variable import IndexVariable, Variable, as_variable
from .variable import concat as concat_vars

Expand Down Expand Up @@ -145,7 +146,13 @@ def concat(
"`data_vars` and `coords` arguments"
)

if compat not in ["equals", "identical", "override", "no_conflicts"]:
if compat not in [
"equals",
"identical",
"override",
"no_conflicts",
"broadcast_equals",
]:
raise ValueError(
"compat=%r invalid: must be 'equals', 'identical or 'override'" % compat
)
Expand Down Expand Up @@ -186,7 +193,7 @@ def _calc_concat_dim_coord(dim):
return dim, coord


def _calc_concat_over(datasets, dim, data_vars, coords):
def _calc_concat_over(datasets, dim, data_vars, coords, compat):
"""
Determine which dataset variables need to be concatenated in the result,
and which can simply be taken from the first dataset.
Expand Down Expand Up @@ -225,14 +232,18 @@ def process_subset_opt(opt, subset):
for ds_rhs in datasets[1:]:
v_rhs = ds_rhs.variables[k].compute()
computed.append(v_rhs)
if not v_lhs.equals(v_rhs):
if not getattr(v_lhs, compat)(v_rhs):
concat_over.add(k)
equals[k] = False
# computed variables are not to be re-computed
# again in the future
for ds, v in zip(datasets[1:], computed):
ds.variables[k].data = v.data
break
elif compat == "broadcast_equals":
# (REMOVE): This is from merge.unique_variable
dim_lengths = broadcast_dimension_size(computed)
v_lhs = v_lhs.set_dims(dim_lengths)
else:
equals[k] = True

Expand Down Expand Up @@ -291,7 +302,7 @@ def _dataset_concat(
*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value
)

concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords, compat)

def insert_result_variable(k, v):
assert isinstance(v, Variable)
Expand Down Expand Up @@ -338,7 +349,7 @@ def insert_result_variable(k, v):
is_equal = equals[k]
except KeyError:
result_vars[k].load()
is_equal = v.equals(result_vars[k])
is_equal = getattr(v, compat)(result_vars[k])
if not is_equal:
raise ValueError(
"Variable '%s' is not equal across datasets. "
Expand Down

0 comments on commit 1b4858f

Please sign in to comment.