From 19ec0fe2efec0150663607d05e6eaa84dd706c26 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 15 Jul 2022 00:26:51 +0200 Subject: [PATCH 1/9] Switch to T_DataArray in concat --- xarray/core/concat.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 92e81dca4e3..cda9589c8b5 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -6,6 +6,7 @@ from . import dtypes, utils from .alignment import align +from .types import T_DataArray from .duck_array_ops import lazy_array_equiv from .indexes import Index, PandasIndex from .merge import ( @@ -26,7 +27,7 @@ @overload def concat( objs: Iterable[Dataset], - dim: Hashable | DataArray | pd.Index, + dim: Hashable | T_DataArray | pd.Index, data_vars: ConcatOptions | list[Hashable] = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -40,8 +41,8 @@ def concat( @overload def concat( - objs: Iterable[DataArray], - dim: Hashable | DataArray | pd.Index, + objs: Iterable[T_DataArray], + dim: Hashable | T_DataArray | pd.Index, data_vars: ConcatOptions | list[Hashable] = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -49,7 +50,7 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> DataArray: +) -> T_DataArray: ... @@ -430,7 +431,7 @@ def _parse_datasets( def _dataset_concat( datasets: list[Dataset], - dim: str | DataArray | pd.Index, + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, @@ -618,8 +619,8 @@ def get_indexes(name): def _dataarray_concat( - arrays: Iterable[DataArray], - dim: str | DataArray | pd.Index, + arrays: Iterable[T_DataArray], + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, @@ -627,7 +628,7 @@ def _dataarray_concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> DataArray: +) -> T_DataArray: from .dataarray import DataArray arrays = list(arrays) From 7e239cab999b7513c549a0b54090ec3d68bbe199 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 15 Jul 2022 00:30:20 +0200 Subject: [PATCH 2/9] Switch tp T_Dataset in concat --- xarray/core/concat.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index cda9589c8b5..fd0677afb54 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -6,7 +6,7 @@ from . import dtypes, utils from .alignment import align -from .types import T_DataArray +from .types import T_DataArray, T_Dataset from .duck_array_ops import lazy_array_equiv from .indexes import Index, PandasIndex from .merge import ( @@ -26,7 +26,7 @@ @overload def concat( - objs: Iterable[Dataset], + objs: Iterable[T_Dataset], dim: Hashable | T_DataArray | pd.Index, data_vars: ConcatOptions | list[Hashable] = "all", coords: ConcatOptions | list[Hashable] = "different", @@ -35,7 +35,7 @@ def concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> Dataset: +) -> T_Dataset: ... @@ -403,7 +403,7 @@ def process_subset_opt(opt, subset): # determine dimensional coordinate names and a dict mapping name to DataArray def _parse_datasets( - datasets: Iterable[Dataset], + datasets: Iterable[T_Dataset], ) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]: dims: set[Hashable] = set() @@ -430,7 +430,7 @@ def _parse_datasets( def _dataset_concat( - datasets: list[Dataset], + datasets: list[T_Dataset], dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], @@ -439,7 +439,7 @@ def _dataset_concat( fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", -) -> Dataset: +) -> T_Dataset: """ Concatenate a sequence of datasets along a new or existing dimension """ From 55963e8929954461fc39e12d22511d1b1465c3fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Jul 2022 22:32:34 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index fd0677afb54..3a0206a2919 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -6,7 +6,6 @@ from . import dtypes, utils from .alignment import align -from .types import T_DataArray, T_Dataset from .duck_array_ops import lazy_array_equiv from .indexes import Index, PandasIndex from .merge import ( @@ -15,6 +14,7 @@ merge_attrs, merge_collected, ) +from .types import T_DataArray, T_Dataset from .variable import Variable from .variable import concat as concat_vars From 216cb99fce2e6dd0623814618f700ef0e67a2a4c Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 15 Jul 2022 10:49:13 +0200 Subject: [PATCH 4/9] Update concat.py --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 3a0206a2919..4bd803458a1 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -431,7 +431,7 @@ def _parse_datasets( def _dataset_concat( datasets: list[T_Dataset], - dim: str | T_DataArray | pd.Index, + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, From c190b36ceda84b0af0e7f30c009233cda1af2b48 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Jul 2022 08:51:07 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 4bd803458a1..3a0206a2919 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -431,7 +431,7 @@ def _parse_datasets( def _dataset_concat( datasets: list[T_Dataset], - dim: str | T_DataArray | pd.Index, + dim: str | T_DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], compat: CompatOptions, From 8abb7b0e335f9b614cc1ef405d328e826c99e14f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 16 Jul 2022 17:38:50 +0200 Subject: [PATCH 6/9] cast types --- xarray/core/concat.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 4bd803458a1..aa40520627f 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload import pandas as pd @@ -483,7 +483,8 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - datasets = [ds.expand_dims(dim) for ds in datasets] + # TODO: Overriding type because .expand_dims has incorrect typing: + datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -591,7 +592,7 @@ def get_indexes(name): # preserves original variable order result_vars[name] = result_vars.pop(name) - result = Dataset(result_vars, attrs=result_attrs) + result = cast(T_Dataset, Dataset(result_vars, attrs=result_attrs)) absent_coord_names = coord_names - set(result.variables) if absent_coord_names: @@ -651,7 +652,8 @@ def _dataarray_concat( if compat == "identical": raise ValueError("array names not identical") else: - arr = arr.rename(name) + # TODO: Overriding type because .rename has incorrect typing: + arr = cast(T_DataArray, arr.rename(name)) datasets.append(arr._to_temp_dataset()) ds = _dataset_concat( From 55e1b4ddaf69fac096d785ec95af1d7e029b80e8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 16 Jul 2022 18:32:33 +0200 Subject: [PATCH 7/9] Update concat.py --- xarray/core/concat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index dcefc6acd6b..3bfe9141e54 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -19,8 +19,6 @@ from .variable import concat as concat_vars if TYPE_CHECKING: - from .dataarray import DataArray - from .dataset import Dataset from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions From c979a57e532b42a1e0696c18069795533edec42a Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 16 Jul 2022 20:11:47 +0200 Subject: [PATCH 8/9] Update concat.py --- xarray/core/concat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 3bfe9141e54..0dfc89483a6 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -590,7 +590,8 @@ def get_indexes(name): # preserves original variable order result_vars[name] = result_vars.pop(name) - result = cast(T_Dataset, Dataset(result_vars, attrs=result_attrs)) + # result = cast(T_Dataset, Dataset(result_vars, attrs=result_attrs)) + result = type(datasets[0])(result_vars, attrs=result_attrs) absent_coord_names = coord_names - set(result.variables) if absent_coord_names: From 63e0f1d66de2de1bb8444a5b1c6ff743d7526419 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sat, 16 Jul 2022 22:42:49 +0200 Subject: [PATCH 9/9] Update concat.py --- xarray/core/concat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 592829dd49c..f7cc30b9eab 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -590,7 +590,6 @@ def get_indexes(name): # preserves original variable order result_vars[name] = result_vars.pop(name) - # result = cast(T_Dataset, Dataset(result_vars, attrs=result_attrs)) result = type(datasets[0])(result_vars, attrs=result_attrs) absent_coord_names = coord_names - set(result.variables)