# Switch to T_DataArray and T_Dataset in concat #6784

Merged 12 commits on Jul 18, 2022.

xarray/core/concat.py: 39 changes (20 additions and 19 deletions)
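For context, `T_Dataset` and `T_DataArray` are `TypeVar`s bound to `Dataset` and `DataArray` (defined in `xarray/core/types.py`). Annotating `concat` with them lets the return type track the concrete class of the inputs instead of collapsing to the base class. A minimal sketch of the idea; the `first_dataset` helper is made up for illustration:

```python
from typing import Iterable, TypeVar

import xarray as xr

# Simplified stand-in for xarray.core.types.T_Dataset:
T_Dataset = TypeVar("T_Dataset", bound=xr.Dataset)


def first_dataset(objs: Iterable[T_Dataset]) -> T_Dataset:
    # Because the signature uses a TypeVar, mypy infers the element
    # type of `objs` and returns that same type to the caller.
    return next(iter(objs))


class MyDataset(xr.Dataset):
    __slots__ = ()  # xarray subclasses are expected to declare __slots__


ds = first_dataset([MyDataset(), MyDataset()])  # typed as MyDataset, not Dataset
```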
```diff
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload
+from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload

 import pandas as pd
@@ -14,42 +14,41 @@
     merge_attrs,
     merge_collected,
 )
+from .types import T_DataArray, T_Dataset
 from .variable import Variable
 from .variable import concat as concat_vars

 if TYPE_CHECKING:
-    from .dataarray import DataArray
-    from .dataset import Dataset
     from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions


 @overload
 def concat(
-    objs: Iterable[Dataset],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_Dataset],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     ...


 @overload
 def concat(
-    objs: Iterable[DataArray],
-    dim: Hashable | DataArray | pd.Index,
+    objs: Iterable[T_DataArray],
+    dim: Hashable | T_DataArray | pd.Index,
     data_vars: ConcatOptions | list[Hashable] = "all",
     coords: ConcatOptions | list[Hashable] = "different",
     compat: CompatOptions = "equals",
     positions: Iterable[Iterable[int]] | None = None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     ...
```


```diff
@@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):

 # determine dimensional coordinate names and a dict mapping name to DataArray
 def _parse_datasets(
-    datasets: Iterable[Dataset],
+    datasets: Iterable[T_Dataset],
 ) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

     dims: set[Hashable] = set()
@@ -429,16 +428,16 @@ def _parse_datasets(


 def _dataset_concat(
-    datasets: list[Dataset],
-    dim: str | DataArray | pd.Index,
+    datasets: list[T_Dataset],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> Dataset:
+) -> T_Dataset:
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
@@ -482,7 +481,8 @@ def _dataset_concat(

     # case where concat dimension is a coordinate or data_var but not a dimension
     if (dim in coord_names or dim in data_names) and dim not in dim_names:
-        datasets = [ds.expand_dims(dim) for ds in datasets]
+        # TODO: Overriding type because .expand_dims has incorrect typing:
+        datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]

     # determine which variables to concatenate
     concat_over, equals, concat_dim_lengths = _calc_concat_over(
@@ -590,7 +590,7 @@ def get_indexes(name):
             # preserves original variable order
             result_vars[name] = result_vars.pop(name)

-    result = Dataset(result_vars, attrs=result_attrs)
+    result = cast(T_Dataset, Dataset(result_vars, attrs=result_attrs))
```
**Contributor:**

Here the actual result value is hardcoded to be of type `Dataset`, which would not match `T_Dataset` if `datasets` contains some subclass of `xarray.Dataset`.

Suggestion to try (which may not need the cast):

```python
result = type(datasets[0])(result_vars, attrs=result_attrs)
```
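A quick sketch of the issue (the subclass and data here are made up for illustration):

```python
import xarray as xr


class UnitsDataset(xr.Dataset):
    # Hypothetical user-defined Dataset subclass.
    __slots__ = ()


datasets = [UnitsDataset({"a": ("x", [1])}), UnitsDataset({"a": ("x", [2])})]
result_vars = datasets[0].data_vars

hardcoded = xr.Dataset(result_vars)          # always a plain Dataset
preserved = type(datasets[0])(result_vars)   # keeps the UnitsDataset type

print(type(hardcoded).__name__)  # Dataset
print(type(preserved).__name__)  # UnitsDataset
```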

**Contributor Author:**

Nice find! I do wonder if we have this issue in more places?

**Collaborator:**

Yes, hardcoding `Dataset` / `DataArray` rather than using the class of the argument is quite endemic.

We would need to make some widespread changes for this to work correctly; I think that would be welcome. We'd probably need to find a good way to test this across methods too (though maybe mypy could help there?).

And we might need a property on `Dataset` for what class to use when returning a `DataArray`, which becomes more complicated.
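One possible shape for that property idea, purely as a sketch (nothing like `_dataarray_class` exists in xarray; the name is invented here):

```python
import xarray as xr


class MyDataArray(xr.DataArray):
    __slots__ = ()


class MyDataset(xr.Dataset):
    __slots__ = ()

    @property
    def _dataarray_class(self) -> type[xr.DataArray]:
        # Hypothetical hook: which class Dataset methods should
        # instantiate when they hand back a DataArray.
        return MyDataArray
```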

**Member:**

Looks like this broke somebody's subclass that doesn't support the `attrs` argument: #6827
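The failure mode is presumably something like the following (a guess at the shape of the report, not the actual code from #6827):

```python
import xarray as xr


class NarrowDataset(xr.Dataset):
    __slots__ = ()

    def __init__(self, data_vars=None):  # narrower signature: no `attrs`
        super().__init__(data_vars)


ds = NarrowDataset({"a": ("x", [1])})
# What the merged code effectively does when rebuilding the result:
type(ds)({"a": ("x", [2])}, attrs={})  # TypeError: unexpected keyword 'attrs'
```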


```diff
     absent_coord_names = coord_names - set(result.variables)
     if absent_coord_names:
@@ -618,16 +618,16 @@ def get_indexes(name):


 def _dataarray_concat(
-    arrays: Iterable[DataArray],
-    dim: str | DataArray | pd.Index,
+    arrays: Iterable[T_DataArray],
+    dim: str | T_DataArray | pd.Index,
     data_vars: str | list[str],
     coords: str | list[str],
     compat: CompatOptions,
     positions: Iterable[Iterable[int]] | None,
     fill_value: object = dtypes.NA,
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
-) -> DataArray:
+) -> T_DataArray:
     from .dataarray import DataArray

     arrays = list(arrays)
@@ -650,7 +650,8 @@ def _dataarray_concat(
         if compat == "identical":
             raise ValueError("array names not identical")
         else:
-            arr = arr.rename(name)
+            # TODO: Overriding type because .rename has incorrect typing:
+            arr = cast(T_DataArray, arr.rename(name))
         datasets.append(arr._to_temp_dataset())

     ds = _dataset_concat(
```