diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 34cd2c82d92..f7cc30b9eab 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload import pandas as pd @@ -14,19 +14,18 @@ merge_attrs, merge_collected, ) +from .types import T_DataArray, T_Dataset from .variable import Variable from .variable import concat as concat_vars if TYPE_CHECKING: - from .dataarray import DataArray - from .dataset import Dataset from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions @overload def concat( - objs: Iterable[Dataset], - dim: Hashable | DataArray | pd.Index, + objs: Iterable[T_Dataset], + dim: Hashable | T_DataArray | pd.Index, data_vars: ConcatOptions | list[Hashable] = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -34,14 +33,14 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> Dataset: +) -> T_Dataset: ... @overload def concat( - objs: Iterable[DataArray], - dim: Hashable | DataArray | pd.Index, + objs: Iterable[T_DataArray], + dim: Hashable | T_DataArray | pd.Index, data_vars: ConcatOptions | list[Hashable] = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -49,7 +48,7 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> DataArray: +) -> T_DataArray: ... @@ -402,7 +401,7 @@ def process_subset_opt(opt, subset): # determine dimensional coordinate names and a dict mapping name to DataArray def _parse_datasets( - datasets: Iterable[Dataset], + datasets: Iterable[T_Dataset], ) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]: dims: set[Hashable] = set() @@ -429,8 +428,8 @@ def _parse_datasets( def _dataset_concat( - datasets: list[Dataset], - dim: str | DataArray | pd.Index, + datasets: list[T_Dataset], + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, @@ -438,7 +437,7 @@ def _dataset_concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> Dataset: +) -> T_Dataset: """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -482,7 +481,8 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - datasets = [ds.expand_dims(dim) for ds in datasets] + # TODO: Overriding type because .expand_dims has incorrect typing: + datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -590,7 +590,7 @@ def get_indexes(name): # preserves original variable order result_vars[name] = result_vars.pop(name) - result = Dataset(result_vars, attrs=result_attrs) + result = type(datasets[0])(result_vars, attrs=result_attrs) absent_coord_names = coord_names - set(result.variables) if absent_coord_names: @@ -618,8 +618,8 @@ def get_indexes(name): def _dataarray_concat( - arrays: Iterable[DataArray], - dim: str | DataArray | pd.Index, + arrays: Iterable[T_DataArray], + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, @@ -627,7 +627,7 @@ def _dataarray_concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> DataArray: +) -> T_DataArray: from .dataarray import DataArray arrays = list(arrays) @@ -650,7 +650,8 @@ def _dataarray_concat( if compat == "identical": raise ValueError("array names not identical") else: - arr = arr.rename(name) + # TODO: Overriding type because .rename has incorrect typing: + arr = cast(T_DataArray, arr.rename(name)) datasets.append(arr._to_temp_dataset()) ds = _dataset_concat(