Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fill missing data_vars during concat by reindexing #7400

Merged
merged 39 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
829a744
ENH: fill missing variables during concat by reindexing
kmuehlbauer Dec 22, 2022
871d2db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2022
14bb779
FIX: use `Any` for type of `fill_value` as this seems consistent with…
kmuehlbauer Jan 4, 2023
6b49713
ENH: add tests
kmuehlbauer Jan 4, 2023
43862a1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2023
5dd3e78
typing
kmuehlbauer Jan 8, 2023
aed0f6e
typing
kmuehlbauer Jan 8, 2023
fa1aba5
typing
kmuehlbauer Jan 8, 2023
308c009
use None instead of False
kmuehlbauer Jan 8, 2023
58813e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
a248d5b
concatenate variable in any case if variable has concat_dim
kmuehlbauer Jan 8, 2023
1a6ad5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
b9d0d02
add tests from @scottcha #3545
scottcha Jan 8, 2023
2db43b9
typing
kmuehlbauer Jan 8, 2023
e54d423
fix typing
kmuehlbauer Jan 8, 2023
7fe0ce2
fix tests with, finalize typing
kmuehlbauer Jan 9, 2023
ffd13f6
add whats-new.rst entry
kmuehlbauer Jan 9, 2023
146b740
Update xarray/tests/test_concat.py
kmuehlbauer Jan 9, 2023
3196e83
Update xarray/tests/test_concat.py
kmuehlbauer Jan 9, 2023
e53cfcc
add TODO, fix numpy.random.default_rng
kmuehlbauer Jan 9, 2023
c41419f
change np.random to use Generator
kmuehlbauer Jan 10, 2023
b2b0b18
move code for variable order into dedicated function, merge with _par…
kmuehlbauer Jan 10, 2023
bb0a8ae
fix comment
kmuehlbauer Jan 10, 2023
e266fe5
Use order from first dataset, append missing variables to the end
kmuehlbauer Jan 16, 2023
6120796
ensure fill_value is dict
kmuehlbauer Jan 16, 2023
5733d93
ensure fill_value in align
kmuehlbauer Jan 16, 2023
9891439
simplify combined_var, fix test
kmuehlbauer Jan 16, 2023
94b9ba9
revert fill_value for alignment.py
kmuehlbauer Jan 16, 2023
70be70f
derive variable order in order of appearance as suggested per review
kmuehlbauer Jan 17, 2023
c3eda8f
remove unneeded enumerate
kmuehlbauer Jan 17, 2023
70f38ab
Use alignment.reindex_variables instead.
dcherian Jan 17, 2023
4825c94
small cleanup
dcherian Jan 17, 2023
a71f633
Update doc/whats-new.rst
kmuehlbauer Jan 17, 2023
92e6108
adapt tests as per review request, fix ensure_common_dims
kmuehlbauer Jan 18, 2023
8610397
adapt tests as per review request
kmuehlbauer Jan 19, 2023
6cb163e
fix whats-new.rst
kmuehlbauer Jan 19, 2023
070c0fb
add whats-new.rst entry
kmuehlbauer Jan 9, 2023
de60890
Add additional test with scalar data_var
dcherian Jan 19, 2023
b14c8f6
remove erroneous content from whats-new.rst
kmuehlbauer Jan 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Deprecations
Bug fixes
~~~~~~~~~

- :py:func:`xarray.concat` can now concatenate variables present in some datasets but
not others (:issue:`508`, :pull:`7400`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_ and `Scott Chamberlin <https://github.com/scottcha>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
71 changes: 55 additions & 16 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd

from xarray.core import dtypes, utils
from xarray.core.alignment import align
from xarray.core.alignment import align, reindex_variables
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
from xarray.core.merge import (
Expand Down Expand Up @@ -378,7 +378,9 @@ def process_subset_opt(opt, subset):

elif opt == "all":
concat_over.update(
set(getattr(datasets[0], subset)) - set(datasets[0].dims)
set().union(
*list(set(getattr(d, subset)) - set(d.dims) for d in datasets)
)
kmuehlbauer marked this conversation as resolved.
Show resolved Hide resolved
)
elif opt == "minimal":
pass
Expand Down Expand Up @@ -406,19 +408,26 @@ def process_subset_opt(opt, subset):

# determine dimensional coordinate names and a dict mapping name to DataArray
def _parse_datasets(
datasets: Iterable[T_Dataset],
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

datasets: list[T_Dataset],
) -> tuple[
dict[Hashable, Variable],
dict[Hashable, int],
set[Hashable],
set[Hashable],
list[Hashable],
]:
dims: set[Hashable] = set()
all_coord_names: set[Hashable] = set()
data_vars: set[Hashable] = set() # list of data_vars
dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable
dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables
variables_order: dict[Hashable, Variable] = {} # variables in order of appearance

for ds in datasets:
dims_sizes.update(ds.dims)
all_coord_names.update(ds.coords)
data_vars.update(ds.data_vars)
variables_order.update(ds.variables)

# preserves ordering of dimensions
for dim in ds.dims:
Expand All @@ -429,7 +438,7 @@ def _parse_datasets(
dim_coords[dim] = ds.coords[dim].variable
dims = dims | set(ds.dims)

return dim_coords, dims_sizes, all_coord_names, data_vars
return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order)


def _dataset_concat(
Expand All @@ -439,7 +448,7 @@ def _dataset_concat(
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
fill_value: object = dtypes.NA,
fill_value: Any = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> T_Dataset:
Expand Down Expand Up @@ -471,7 +480,9 @@ def _dataset_concat(
align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
)

dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets)
dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets(
datasets
)
dim_names = set(dim_coords)
unlabeled_dims = dim_names - coord_names

Expand Down Expand Up @@ -525,7 +536,7 @@ def _dataset_concat(

# we've already verified everything is consistent; now, calculate
# shared dimension sizes so we can expand the necessary variables
def ensure_common_dims(vars):
def ensure_common_dims(vars, concat_dim_lengths):
# ensure each variable with the given name shares the same
# dimensions and the same shape for all of them except along the
# concat dimension
Expand Down Expand Up @@ -553,16 +564,35 @@ def get_indexes(name):
data = var.set_dims(dim).values
yield PandasIndex(data, dim, coord_dtype=var.dtype)

# create concatenation index, needed for later reindexing
concat_index = list(range(sum(concat_dim_lengths)))

# stack up each variable and/or index to fill-out the dataset (in order)
# n.b. this loop preserves variable order, needed for groupby.
for name in datasets[0].variables:
for name in vars_order:
if name in concat_over and name not in result_indexes:
try:
vars = ensure_common_dims([ds[name].variable for ds in datasets])
except KeyError:
raise ValueError(f"{name!r} is not present in all datasets.")

# Try concatenate the indexes, concatenate the variables when no index
variables = []
variable_index = []
var_concat_dim_length = []
for i, ds in enumerate(datasets):
if name in ds.variables:
variables.append(ds[name].variable)
# add to variable index, needed for reindexing
var_idx = [
sum(concat_dim_lengths[:i]) + k
for k in range(concat_dim_lengths[i])
]
variable_index.extend(var_idx)
var_concat_dim_length.append(len(var_idx))
else:
# raise if coordinate not in all datasets
if name in coord_names:
kmuehlbauer marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
f"coordinate {name!r} not present in all datasets."
)
vars = ensure_common_dims(variables, var_concat_dim_length)

# Try to concatenate the indexes, concatenate the variables when no index
# is found on all datasets.
indexes: list[Index] = list(get_indexes(name))
if indexes:
Expand All @@ -589,6 +619,15 @@ def get_indexes(name):
combined_var = concat_vars(
vars, dim, positions, combine_attrs=combine_attrs
)
# reindex if variable is not present in all datasets
if len(variable_index) < len(concat_index):
combined_var = reindex_variables(
variables={name: combined_var},
dim_pos_indexers={
dim: pd.Index(variable_index).get_indexer(concat_index)
dcherian marked this conversation as resolved.
Show resolved Hide resolved
},
fill_value=fill_value,
)[name]
result_vars[name] = combined_var

elif name in result_vars:
Expand Down
Loading