From 2705c63e0c03a21d2bbce3a337fac60dd6f6da59 Mon Sep 17 00:00:00 2001
From: Illviljan <14371165+Illviljan@users.noreply.github.com>
Date: Fri, 13 Aug 2021 15:41:42 +0200
Subject: [PATCH 01/20] Use same bool validator as other inputs (#5703)

---
 xarray/core/options.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xarray/core/options.py b/xarray/core/options.py
index 71358916243..1d916ff0f7c 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -56,7 +56,7 @@ def _positive_integer(value):
     FILE_CACHE_MAXSIZE: _positive_integer,
     KEEP_ATTRS: lambda choice: choice in [True, False, "default"],
     WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool),
-    USE_BOTTLENECK: lambda choice: choice in [True, False],
+    USE_BOTTLENECK: lambda value: isinstance(value, bool),
 }

From 41099669c63c7dbb51b112fdec09dfae07379fe7 Mon Sep 17 00:00:00 2001
From: Tomas Chor
Date: Mon, 16 Aug 2021 19:29:54 -0700
Subject: [PATCH 02/20] Improves rendering of complex LaTeX expressions as
 `long_name`s when plotting (#5682)

* try to wrap texts differently if they're latex code

* added space around modulo operator

* prevents breaking up of long words

* added test for label_from_attrs() function with a long latex string

* Update xarray/plot/utils.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* Update xarray/tests/test_utils.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* moved test_latex_name_isnt_split() to test_plots

* Apply suggestions from code review

* fixed test_latex_name_isnt_split()

* ran code through pre-commit

* updated whats-new.rst

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
Co-authored-by: Deepak Cherian
---
 doc/whats-new.rst         | 2 ++
 xarray/plot/utils.py      | 8 +++++++-
 xarray/tests/test_plot.py | 7 +++++++
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 9ac9639b8c1..52114be991b 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -22,6 +22,8 @@ v0.19.1 (unreleased)
 
 New Features
 ~~~~~~~~~~~~
+- Xarray now does a better job rendering variable names that are long LaTeX sequences when plotting (:issue:`5681`, :pull:`5682`).
+  By `Tomas Chor `_.
 - Add an option to disable the use of ``bottleneck`` (:pull:`5560`)
   By `Justus Magin `_.
 - Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`).
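
Editor's note on PATCH 01 above: the change matters because Python compares
``bool`` and numbers by value (``True == 1``, ``False == 0``), so the old
membership test also accepted inputs such as ``1`` or ``0.0``. A minimal
sketch of the difference (illustration only, not part of the patch)::

    old_validator = lambda choice: choice in [True, False]
    new_validator = lambda value: isinstance(value, bool)

    assert old_validator(1)      # 1 == True, so the old check wrongly accepts it
    assert old_validator(0.0)    # 0.0 == False also slips through
    assert not new_validator(1)  # the isinstance check accepts only real bools
    assert new_validator(True)
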
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index f2f296096a5..af5859c1f14 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -490,7 +490,13 @@ def _get_units_from_attrs(da): else: units = _get_units_from_attrs(da) - return "\n".join(textwrap.wrap(name + extra + units, 30)) + # Treat `name` differently if it's a latex sequence + if name.startswith("$") and (name.count("$") % 2 == 0): + return "$\n$".join( + textwrap.wrap(name + extra + units, 60, break_long_words=False) + ) + else: + return "\n".join(textwrap.wrap(name + extra + units, 30)) def _interval_to_mid_points(array): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index ee8bafb8fa7..e358bd47485 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2950,3 +2950,10 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co add_legend=add_legend, add_colorbar=add_colorbar, ) + + +def test_latex_name_isnt_split(): + da = xr.DataArray() + long_latex_name = r"$Ra_s = \mathrm{mean}(\epsilon_k) / \mu M^2_\infty$" + da.attrs = dict(long_name=long_latex_name) + assert label_from_attrs(da) == long_latex_name From de867e613dc14ff80d7b898d85c4171121ad353d Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 18 Aug 2021 04:34:05 +0200 Subject: [PATCH 03/20] Fix errors in test_latex_name_isnt_split for min environments (#5710) --- xarray/tests/test_plot.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index e358bd47485..0c4dd5a29f2 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -185,6 +185,11 @@ def test_label_from_attrs(self): da.attrs.pop("units") assert "a" == label_from_attrs(da) + # Latex strings can be longer without needing a new line: + long_latex_name = r"$Ra_s = \mathrm{mean}(\epsilon_k) / \mu M^2_\infty$" + da.attrs = dict(long_name=long_latex_name) + assert label_from_attrs(da) == long_latex_name + def test1d(self): self.darray[:, 0, 0].plot() @@ -2950,10 +2955,3 @@ def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_co add_legend=add_legend, add_colorbar=add_colorbar, ) - - -def test_latex_name_isnt_split(): - da = xr.DataArray() - long_latex_name = r"$Ra_s = \mathrm{mean}(\epsilon_k) / \mu M^2_\infty$" - da.attrs = dict(long_name=long_latex_name) - assert label_from_attrs(da) == long_latex_name From 56cc0778d5edcb0173425b5850b834457c790210 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 18 Aug 2021 22:12:37 -0700 Subject: [PATCH 04/20] Use isort's float-to-top (#5695) * Use isort's float-to-top * . --- properties/test_encode_decode.py | 3 ++- setup.cfg | 2 +- xarray/core/parallel.py | 19 ++++++++++--------- xarray/tests/test_units.py | 11 ++++++----- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 221083e16a1..4b0643cb2fe 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -4,9 +4,10 @@ These ones pass, just as you'd hope! 
""" -import pytest # isort:skip +import pytest pytest.importorskip("hypothesis") +# isort: split import hypothesis.extra.numpy as npst import hypothesis.strategies as st diff --git a/setup.cfg b/setup.cfg index c44d207bf0f..2059d471e23 100644 --- a/setup.cfg +++ b/setup.cfg @@ -161,7 +161,7 @@ exclude= [isort] profile = black skip_gitignore = true -force_to_top = true +float_to_top = true default_section = THIRDPARTY known_first_party = xarray diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2c7f4249b5e..20ec3608ebb 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,12 +1,3 @@ -try: - import dask - import dask.array - from dask.array.utils import meta_from_array - from dask.highlevelgraph import HighLevelGraph - -except ImportError: - pass - import collections import itertools import operator @@ -31,6 +22,16 @@ from .dataarray import DataArray from .dataset import Dataset +try: + import dask + import dask.array + from dask.array.utils import meta_from_array + from dask.highlevelgraph import HighLevelGraph + +except ImportError: + pass + + T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2140047f38e..543100ef98c 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5,11 +5,6 @@ import pandas as pd import pytest -try: - import matplotlib.pyplot as plt -except ImportError: - pass - import xarray as xr from xarray.core import dtypes, duck_array_ops @@ -23,6 +18,12 @@ from .test_plot import PlotTestCase from .test_variable import _PAD_XR_NP_ARGS +try: + import matplotlib.pyplot as plt +except ImportError: + pass + + pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError From 22548f844888c19ba4a0025057e42396c6c90350 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 19 Aug 2021 10:26:11 -0700 Subject: [PATCH 05/20] Whatsnew for float-to-top (#5714) * Use isort's float-to-top * . * whatsnew for float-to-top --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 52114be991b..69329a73971 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,8 @@ Internal Changes By `Benoit Bovy `_. - Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) By `Jimmy Westling `_. +- Use isort's `float_to_top` config. (:pull:`5695`). + By `Maximilian Roos `_. .. 
_whats-new.0.19.0: From 378b9020afd5e3f3505a8b3b9d56a27b61c97898 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 19 Aug 2021 10:27:00 -0700 Subject: [PATCH 06/20] Change annotations to allow str keys (#5690) * Change typing to allow str keys * Change all incoming Mapping types * Add in some annotated tests * whatsnew --- doc/whats-new.rst | 2 ++ properties/test_pandas_roundtrip.py | 10 +++--- xarray/core/accessor_str.py | 4 +-- xarray/core/common.py | 8 ++--- xarray/core/computation.py | 2 +- xarray/core/coordinates.py | 10 +++--- xarray/core/dataarray.py | 24 +++++++-------- xarray/core/dataset.py | 48 ++++++++++++++--------------- xarray/core/merge.py | 4 +-- xarray/core/utils.py | 2 +- xarray/core/variable.py | 6 ++-- xarray/tests/test_dataset.py | 9 ++++++ 12 files changed, 70 insertions(+), 59 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 69329a73971..4f79a37eb4b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,8 @@ Internal Changes By `Deepak Cherian `_. - Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`). By `Benoit Bovy `_. +- Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`). + By `Maximilian Roos `_. - Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) By `Jimmy Westling `_. - Use isort's `float_to_top` config. (:pull:`5695`). diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5fc097f1f5e..e8cef95c029 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -28,7 +28,7 @@ @st.composite -def datasets_1d_vars(draw): +def datasets_1d_vars(draw) -> xr.Dataset: """Generate datasets with only 1D variables Suitable for converting to pandas dataframes. @@ -49,7 +49,7 @@ def datasets_1d_vars(draw): @given(st.data(), an_array) -def test_roundtrip_dataarray(data, arr): +def test_roundtrip_dataarray(data, arr) -> None: names = data.draw( st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( tuple @@ -62,7 +62,7 @@ def test_roundtrip_dataarray(data, arr): @given(datasets_1d_vars()) -def test_roundtrip_dataset(dataset): +def test_roundtrip_dataset(dataset) -> None: df = dataset.to_dataframe() assert isinstance(df, pd.DataFrame) roundtripped = xr.Dataset(df) @@ -70,7 +70,7 @@ def test_roundtrip_dataset(dataset): @given(numeric_series, st.text()) -def test_roundtrip_pandas_series(ser, ix_name): +def test_roundtrip_pandas_series(ser, ix_name) -> None: # Need to name the index, otherwise Xarray calls it 'dim_0'. ser.index.name = ix_name arr = xr.DataArray(ser) @@ -87,7 +87,7 @@ def test_roundtrip_pandas_series(ser, ix_name): @pytest.mark.xfail @given(numeric_homogeneous_dataframe) -def test_roundtrip_pandas_dataframe(df): +def test_roundtrip_pandas_dataframe(df) -> None: # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. 
df.index.name = "rows" df.columns.name = "cols" diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index e3c35d6e4b6..0f0a256b77a 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -114,7 +114,7 @@ def _apply_str_ufunc( obj: Any, dtype: Union[str, np.dtype, Type] = None, output_core_dims: Union[list, tuple] = ((),), - output_sizes: Mapping[Hashable, int] = None, + output_sizes: Mapping[Any, int] = None, func_args: Tuple = (), func_kwargs: Mapping = {}, ) -> Any: @@ -227,7 +227,7 @@ def _apply( func: Callable, dtype: Union[str, np.dtype, Type] = None, output_core_dims: Union[list, tuple] = ((),), - output_sizes: Mapping[Hashable, int] = None, + output_sizes: Mapping[Any, int] = None, func_args: Tuple = (), func_kwargs: Mapping = {}, ) -> Any: diff --git a/xarray/core/common.py b/xarray/core/common.py index ab822f576d3..d3001532aa0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -818,7 +818,7 @@ def weighted( def rolling( self, - dim: Mapping[Hashable, int] = None, + dim: Mapping[Any, int] = None, min_periods: int = None, center: Union[bool, Mapping[Hashable, bool]] = False, **window_kwargs: int, @@ -892,7 +892,7 @@ def rolling( def rolling_exp( self, - window: Mapping[Hashable, int] = None, + window: Mapping[Any, int] = None, window_type: str = "span", **window_kwargs, ): @@ -933,7 +933,7 @@ def rolling_exp( def coarsen( self, - dim: Mapping[Hashable, int] = None, + dim: Mapping[Any, int] = None, boundary: str = "exact", side: Union[str, Mapping[Hashable, str]] = "left", coord_func: str = "mean", @@ -1009,7 +1009,7 @@ def coarsen( def resample( self, - indexer: Mapping[Hashable, str] = None, + indexer: Mapping[Any, str] = None, skipna=None, closed: str = None, label: str = None, diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cd9e22d90db..5bfb14793bb 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -400,7 +400,7 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] + variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable] ) -> "Dataset": """Create a dataset as quickly as possible. 
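
Editor's note on PATCH 06: the ``Mapping[Hashable, ...]`` to ``Mapping[Any, ...]``
changes in this patch work around ``Mapping`` being invariant in its key type:
under mypy, a plain ``Dict[str, ...]`` does not satisfy ``Mapping[Hashable, ...]``.
A hedged sketch of the behavior (illustration only; the function names are
made up)::

    from typing import Any, Dict, Hashable, Mapping

    def f_hashable_keys(m: Mapping[Hashable, int]) -> None:
        ...

    def f_any_keys(m: Mapping[Any, int]) -> None:
        ...

    d: Dict[str, int] = {"x": 1}
    f_hashable_keys(d)  # flagged by mypy: Dict[str, int] is not compatible
                        # with Mapping[Hashable, int] (key type is invariant)
    f_any_keys(d)       # accepted: Any is compatible with str keys

Both calls run fine at runtime; only the static check differs, which is why
only annotations change in this patch.
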
diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 767b76d0d12..56afdb9774a 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -158,7 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: return pd.MultiIndex(level_list, code_list, names=names) - def update(self, other: Mapping[Hashable, Any]) -> None: + def update(self, other: Mapping[Any, Any]) -> None: other_vars = getattr(other, "variables", other) coords, indexes = merge_coords( [self.variables, other_vars], priority_arg=1, indexes=self.xindexes @@ -270,7 +270,7 @@ def to_dataset(self) -> "Dataset": return self._data._copy_listed(names) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: from .dataset import calculate_dimensions @@ -333,7 +333,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: from .dataset import calculate_dimensions @@ -376,7 +376,7 @@ def _ipython_key_completions_(self): def assert_coordinate_consistent( - obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable] + obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable] ) -> None: """Make sure the dimension coordinate of obj is consistent with coords. @@ -394,7 +394,7 @@ def assert_coordinate_consistent( def remap_label_indexers( obj: Union["DataArray", "Dataset"], - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, **indexers_kwargs: Any, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 900af885319..e86bae08a3f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -462,7 +462,7 @@ def _replace_maybe_drop_dims( ) return self._replace(variable, coords, name, indexes=indexes) - def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": + def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> "DataArray": if not len(indexes): return self coords = self._coords.copy() @@ -792,7 +792,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self.variable.attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: # Disable type checking to work around mypy bug - see mypy#4167 self.variable.attrs = value # type: ignore[assignment] @@ -803,7 +803,7 @@ def encoding(self) -> Dict[Hashable, Any]: return self.variable.encoding @encoding.setter - def encoding(self, value: Mapping[Hashable, Any]) -> None: + def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = value @property @@ -1110,7 +1110,7 @@ def chunk( def isel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, drop: bool = False, missing_dims: str = "raise", **indexers_kwargs: Any, @@ -1193,7 +1193,7 @@ def isel( def sel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, drop: bool = False, @@ -1498,7 +1498,7 @@ def reindex_like( def reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, copy: bool = True, @@ -1591,7 +1591,7 @@ def reindex( def interp( self, - coords: 
Mapping[Hashable, Any] = None, + coords: Mapping[Any, Any] = None, method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, @@ -1815,7 +1815,7 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs ) -> "DataArray": """Returns a new DataArray with swapped dimensions. @@ -2333,7 +2333,7 @@ def drop( def drop_sel( self, - labels: Mapping[Hashable, Any] = None, + labels: Mapping[Any, Any] = None, *, errors: str = "raise", **labels_kwargs, @@ -3163,7 +3163,7 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr def shift( self, - shifts: Mapping[Hashable, int] = None, + shifts: Mapping[Any, int] = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, ) -> "DataArray": @@ -3210,7 +3210,7 @@ def shift( def roll( self, - shifts: Mapping[Hashable, int] = None, + shifts: Mapping[Any, int] = None, roll_coords: bool = None, **shifts_kwargs: int, ) -> "DataArray": @@ -4433,7 +4433,7 @@ def argmax( def query( self, - queries: Mapping[Hashable, Any] = None, + queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, missing_dims: str = "raise", diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4bfc1ccbdf1..90c395ed39b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -185,7 +185,7 @@ def _get_virtual_variable( return ref_name, var_name, virtual_var -def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashable, int]: +def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, int]: """Calculate the dimensions corresponding to a set of variables. Returns dictionary mapping from dimension names to sizes. Raises ValueError @@ -213,7 +213,7 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl def merge_indexes( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]], - variables: Mapping[Hashable, Variable], + variables: Mapping[Any, Variable], coord_names: Set[Hashable], append: bool = False, ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: @@ -297,9 +297,9 @@ def merge_indexes( def split_indexes( dims_or_levels: Union[Hashable, Sequence[Hashable]], - variables: Mapping[Hashable, Variable], + variables: Mapping[Any, Variable], coord_names: Set[Hashable], - level_coords: Mapping[Hashable, Hashable], + level_coords: Mapping[Any, Hashable], drop: bool = False, ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: """Extract (multi-)indexes (levels) as variables. 
@@ -559,7 +559,7 @@ class _LocIndexer: def __init__(self, dataset: "Dataset"): self.dataset = dataset - def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": + def __getitem__(self, key: Mapping[Any, Any]) -> "Dataset": if not utils.is_dict_like(key): raise TypeError("can only lookup dictionaries from Dataset.loc") return self.dataset.sel(key) @@ -730,9 +730,9 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Hashable, Any] = None, - coords: Mapping[Hashable, Any] = None, - attrs: Mapping[Hashable, Any] = None, + data_vars: Mapping[Any, Any] = None, + coords: Mapping[Any, Any] = None, + attrs: Mapping[Any, Any] = None, ): # TODO(shoyer): expose indexes as a public argument in __init__ @@ -793,7 +793,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self._attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property @@ -2164,7 +2164,7 @@ def chunk( return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Hashable, Any], missing_dims: str = "raise" + self, indexers: Mapping[Any, Any], missing_dims: str = "raise" ) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]: """Here we make sure + indexer has a valid keys @@ -2208,7 +2208,7 @@ def _validate_indexers( yield k, v def _validate_interp_indexers( - self, indexers: Mapping[Hashable, Any] + self, indexers: Mapping[Any, Any] ) -> Iterator[Tuple[Hashable, Variable]]: """Variant of _validate_indexers to be used for interpolation""" for k, v in self._validate_indexers(indexers): @@ -2269,7 +2269,7 @@ def _get_indexers_coords_and_indexes(self, indexers): def isel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, drop: bool = False, missing_dims: str = "raise", **indexers_kwargs: Any, @@ -2361,7 +2361,7 @@ def isel( def _isel_fancy( self, - indexers: Mapping[Hashable, Any], + indexers: Mapping[Any, Any], *, drop: bool, missing_dims: str = "raise", @@ -2403,7 +2403,7 @@ def _isel_fancy( def sel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, drop: bool = False, @@ -2711,7 +2711,7 @@ def reindex_like( def reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, copy: bool = True, @@ -2921,7 +2921,7 @@ def reindex( def _reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, copy: bool = True, @@ -2955,7 +2955,7 @@ def _reindex( def interp( self, - coords: Mapping[Hashable, Any] = None, + coords: Mapping[Any, Any] = None, method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, @@ -3325,7 +3325,7 @@ def _rename_all(self, name_dict, dims_dict): def rename( self, - name_dict: Mapping[Hashable, Hashable] = None, + name_dict: Mapping[Any, Hashable] = None, **names: Hashable, ) -> "Dataset": """Returns a new object with renamed variables and dimensions. @@ -3366,7 +3366,7 @@ def rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims: Hashable + self, dims_dict: Mapping[Any, Hashable] = None, **dims: Hashable ) -> "Dataset": """Returns a new object with renamed dimensions only. 
@@ -3411,7 +3411,7 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self, name_dict: Mapping[Hashable, Hashable] = None, **names: Hashable + self, name_dict: Mapping[Any, Hashable] = None, **names: Hashable ) -> "Dataset": """Returns a new object with renamed variables including coordinates @@ -3449,7 +3449,7 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs ) -> "Dataset": """Returns a new object with swapped dimensions. @@ -5146,7 +5146,7 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self, variables: Mapping[Hashable, Any] = None, **variables_kwargs: Hashable + self, variables: Mapping[Any, Any] = None, **variables_kwargs: Hashable ) -> "Dataset": """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -5322,7 +5322,7 @@ def to_pandas(self) -> Union[pd.Series, pd.DataFrame]: "Please use Dataset.to_dataframe() instead." % len(self.dims) ) - def _to_dataframe(self, ordered_dims: Mapping[Hashable, int]): + def _to_dataframe(self, ordered_dims: Mapping[Any, int]): columns = [k for k in self.variables if k not in self.dims] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) @@ -7391,7 +7391,7 @@ def argmax(self, dim=None, **kwargs): def query( self, - queries: Mapping[Hashable, Any] = None, + queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, missing_dims: str = "raise", diff --git a/xarray/core/merge.py b/xarray/core/merge.py index b8b32bdaa01..eaa5b62b2a9 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -170,7 +170,7 @@ def _assert_compat_valid(compat): def merge_collected( grouped: Dict[Hashable, List[MergeElement]], - prioritized: Mapping[Hashable, MergeElement] = None, + prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", combine_attrs="override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: @@ -319,7 +319,7 @@ def collect_from_coordinates( def merge_coordinates_without_align( objects: "List[Coordinates]", - prioritized: Mapping[Hashable, MergeElement] = None, + prioritized: Mapping[Any, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), combine_attrs: str = "override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index a139d2ef10a..57b7035c940 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -816,7 +816,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( - indexers: Mapping[Hashable, Any], + indexers: Mapping[Any, Any], dims: Union[list, Mapping[Hashable, int]], missing_dims: str, ) -> Mapping[Hashable, Any]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6b971389de7..2798a4ab956 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -876,7 +876,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self._attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property @@ -1137,7 +1137,7 @@ def _to_dense(self): def isel( self: VariableType, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, missing_dims: str = 
"raise", **indexers_kwargs: Any, ) -> VariableType: @@ -1572,7 +1572,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): return result def _unstack_once_full( - self, dims: Mapping[Hashable, int], old_dim: Hashable + self, dims: Mapping[Any, int], old_dim: Hashable ) -> "Variable": """ Unstacks the variable without needing an index. diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9b8b7c748f1..1256a44ad81 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6802,3 +6802,12 @@ def test_from_pint_wrapping_dask(self): result = ds.as_numpy() expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)}) assert_identical(result, expected) + + +def test_string_keys_typing() -> None: + """Tests that string keys to `variables` are permitted by mypy""" + + da = xr.DataArray(np.arange(10), dims=["x"]) + ds = xr.Dataset(dict(x=da)) + mapping = {"y": da} + ds.assign(variables=mapping) From b7737f29e58df923b1fc6ba28a961a259fa98950 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Aug 2021 00:15:23 +0200 Subject: [PATCH 07/20] Add typing to the OPTIONS dict (#5678) * Add typing to the OPTIONS dict * Update options.py * Update options.py * Don't use variables as keys * Update options.py * backwards compatibility * Update options.py * Update options.py * Update options.py * get TypedDict from typing_extensions * Update options.py * minor cleanup * Update options.py * Add Colormap as allowed type * linting, keep_attrs can be str or bool * display_expand* can be string or bool * TypedDict is in typing after 3.8 * Try out Literals for better type narrowing * Literal was added python 3.8 --- xarray/core/options.py | 153 +++++++++++++++++++++++++++-------------- 1 file changed, 103 insertions(+), 50 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index 1d916ff0f7c..524362ce924 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,38 +1,91 @@ +import sys import warnings -ARITHMETIC_JOIN = "arithmetic_join" -CMAP_DIVERGENT = "cmap_divergent" -CMAP_SEQUENTIAL = "cmap_sequential" -DISPLAY_MAX_ROWS = "display_max_rows" -DISPLAY_STYLE = "display_style" -DISPLAY_WIDTH = "display_width" -DISPLAY_EXPAND_ATTRS = "display_expand_attrs" -DISPLAY_EXPAND_COORDS = "display_expand_coords" -DISPLAY_EXPAND_DATA_VARS = "display_expand_data_vars" -DISPLAY_EXPAND_DATA = "display_expand_data" -ENABLE_CFTIMEINDEX = "enable_cftimeindex" -FILE_CACHE_MAXSIZE = "file_cache_maxsize" -KEEP_ATTRS = "keep_attrs" -WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files" -USE_BOTTLENECK = "use_bottleneck" - - -OPTIONS = { - ARITHMETIC_JOIN: "inner", - CMAP_DIVERGENT: "RdBu_r", - CMAP_SEQUENTIAL: "viridis", - DISPLAY_MAX_ROWS: 12, - DISPLAY_STYLE: "html", - DISPLAY_WIDTH: 80, - DISPLAY_EXPAND_ATTRS: "default", - DISPLAY_EXPAND_COORDS: "default", - DISPLAY_EXPAND_DATA_VARS: "default", - DISPLAY_EXPAND_DATA: "default", - ENABLE_CFTIMEINDEX: True, - FILE_CACHE_MAXSIZE: 128, - KEEP_ATTRS: "default", - WARN_FOR_UNCLOSED_FILES: False, - USE_BOTTLENECK: True, +# TODO: Remove this check once python 3.7 is not supported: +if sys.version_info >= (3, 8): + from typing import TYPE_CHECKING, Literal, TypedDict, Union + + if TYPE_CHECKING: + try: + from matplotlib.colors import Colormap + except ImportError: + Colormap = str + + class T_Options(TypedDict): + arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] + cmap_divergent: Union[str, "Colormap"] + cmap_sequential: Union[str, 
"Colormap"] + display_max_rows: int + display_style: Literal["text", "html"] + display_width: int + display_expand_attrs: Literal["default", True, False] + display_expand_coords: Literal["default", True, False] + display_expand_data_vars: Literal["default", True, False] + display_expand_data: Literal["default", True, False] + enable_cftimeindex: bool + file_cache_maxsize: int + keep_attrs: Literal["default", True, False] + warn_for_unclosed_files: bool + use_bottleneck: bool + + +else: + # See GH5624, this is a convoluted way to allow type-checking to use + # `TypedDict` and `Literal` without requiring typing_extensions as a + # required dependency to _run_ the code (it is required to type-check). + try: + from typing import TYPE_CHECKING, Union + + from typing_extensions import Literal, TypedDict + + if TYPE_CHECKING: + try: + from matplotlib.colors import Colormap + except ImportError: + Colormap = str + + class T_Options(TypedDict): + arithmetic_join: Literal["inner", "outer", "left", "right", "exact"] + cmap_divergent: Union[str, "Colormap"] + cmap_sequential: Union[str, "Colormap"] + display_max_rows: int + display_style: Literal["text", "html"] + display_width: int + display_expand_attrs: Literal["default", True, False] + display_expand_coords: Literal["default", True, False] + display_expand_data_vars: Literal["default", True, False] + display_expand_data: Literal["default", True, False] + enable_cftimeindex: bool + file_cache_maxsize: int + keep_attrs: Literal["default", True, False] + warn_for_unclosed_files: bool + use_bottleneck: bool + + except ImportError: + from typing import TYPE_CHECKING, Any, Dict, Hashable + + if TYPE_CHECKING: + raise + else: + T_Options = Dict[Hashable, Any] + + +OPTIONS: T_Options = { + "arithmetic_join": "inner", + "cmap_divergent": "RdBu_r", + "cmap_sequential": "viridis", + "display_max_rows": 12, + "display_style": "html", + "display_width": 80, + "display_expand_attrs": "default", + "display_expand_coords": "default", + "display_expand_data_vars": "default", + "display_expand_data": "default", + "enable_cftimeindex": True, + "file_cache_maxsize": 128, + "keep_attrs": "default", + "warn_for_unclosed_files": False, + "use_bottleneck": True, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -44,19 +97,19 @@ def _positive_integer(value): _VALIDATORS = { - ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, - DISPLAY_MAX_ROWS: _positive_integer, - DISPLAY_STYLE: _DISPLAY_OPTIONS.__contains__, - DISPLAY_WIDTH: _positive_integer, - DISPLAY_EXPAND_ATTRS: lambda choice: choice in [True, False, "default"], - DISPLAY_EXPAND_COORDS: lambda choice: choice in [True, False, "default"], - DISPLAY_EXPAND_DATA_VARS: lambda choice: choice in [True, False, "default"], - DISPLAY_EXPAND_DATA: lambda choice: choice in [True, False, "default"], - ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), - FILE_CACHE_MAXSIZE: _positive_integer, - KEEP_ATTRS: lambda choice: choice in [True, False, "default"], - WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), - USE_BOTTLENECK: lambda value: isinstance(value, bool), + "arithmetic_join": _JOIN_OPTIONS.__contains__, + "display_max_rows": _positive_integer, + "display_style": _DISPLAY_OPTIONS.__contains__, + "display_width": _positive_integer, + "display_expand_attrs": lambda choice: choice in [True, False, "default"], + "display_expand_coords": lambda choice: choice in [True, False, "default"], + "display_expand_data_vars": lambda choice: choice in [True, False, "default"], + 
"display_expand_data": lambda choice: choice in [True, False, "default"], + "enable_cftimeindex": lambda value: isinstance(value, bool), + "file_cache_maxsize": _positive_integer, + "keep_attrs": lambda choice: choice in [True, False, "default"], + "warn_for_unclosed_files": lambda value: isinstance(value, bool), + "use_bottleneck": lambda value: isinstance(value, bool), } @@ -75,8 +128,8 @@ def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): _SETTERS = { - ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex, - FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, + "enable_cftimeindex": _warn_on_setting_enable_cftimeindex, + "file_cache_maxsize": _set_file_cache_maxsize, } @@ -175,9 +228,9 @@ def __init__(self, **kwargs): f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}" ) if k in _VALIDATORS and not _VALIDATORS[k](v): - if k == ARITHMETIC_JOIN: + if k == "arithmetic_join": expected = f"Expected one of {_JOIN_OPTIONS!r}" - elif k == DISPLAY_STYLE: + elif k == "display_style": expected = f"Expected one of {_DISPLAY_OPTIONS!r}" else: expected = "" From efe7e5f5c0bc94085ab69341364e06877794fd33 Mon Sep 17 00:00:00 2001 From: Giacomo Caria <44147817+gcaria@users.noreply.github.com> Date: Fri, 20 Aug 2021 00:16:19 +0200 Subject: [PATCH 08/20] Remove suggestion to install pytest-xdist in docs (#5713) --- doc/contributing.rst | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/doc/contributing.rst b/doc/contributing.rst index d73d18d5df7..779aba7536a 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -628,13 +628,7 @@ Or with one of the following constructs:: pytest xarray/tests/[test-module].py::[TestClass]::[test_method] Using `pytest-xdist `_, one can -speed up local testing on multicore machines. To use this feature, you will -need to install `pytest-xdist` via:: - - pip install pytest-xdist - - -Then, run pytest with the optional -n argument:: +speed up local testing on multicore machines, by running pytest with the optional -n argument:: pytest xarray -n 4 From 3cc3ab1f5504dcab07a0757343f9af1ff4a05ebb Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 19 Aug 2021 16:18:56 -0600 Subject: [PATCH 09/20] Refactor more groupby and resample tests (#5707) --- xarray/tests/test_dataarray.py | 739 ----------------------- xarray/tests/test_dataset.py | 242 -------- xarray/tests/test_groupby.py | 1020 +++++++++++++++++++++++++++++++- 3 files changed, 1019 insertions(+), 982 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8ab8bc872da..c3223432b38 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -8,7 +8,6 @@ import pandas as pd import pytest from pandas.core.computation.ops import UndefinedVariableError -from pandas.tseries.frequencies import to_offset import xarray as xr from xarray import ( @@ -1263,25 +1262,6 @@ def test_selection_multiindex_from_level(self): expected = data.isel(xy=[0, 1]).unstack("xy").squeeze("y").drop_vars("y") assert_equal(actual, expected) - def test_stack_groupby_unsorted_coord(self): - data = [[0, 1], [2, 3]] - data_flat = [0, 1, 2, 3] - dims = ["x", "y"] - y_vals = [2, 3] - - arr = xr.DataArray(data, dims=dims, coords={"y": y_vals}) - actual1 = arr.stack(z=dims).groupby("z").first() - midx1 = pd.MultiIndex.from_product([[0, 1], [2, 3]], names=dims) - expected1 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx1}) - xr.testing.assert_equal(actual1, expected1) - - # GH: 3287. 
Note that y coord values are not in sorted order. - arr = xr.DataArray(data, dims=dims, coords={"y": y_vals[::-1]}) - actual2 = arr.stack(z=dims).groupby("z").first() - midx2 = pd.MultiIndex.from_product([[0, 1], [3, 2]], names=dims) - expected2 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx2}) - xr.testing.assert_equal(actual2, expected2) - def test_virtual_default_coords(self): array = DataArray(np.zeros((5,)), dims="x") expected = DataArray(range(5), dims="x", name="x") @@ -1446,12 +1426,6 @@ def test_assign_coords(self): expected = DataArray(10, {"c": 42}) assert_identical(actual, expected) - array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") - actual = array.groupby("c").assign_coords(d=lambda a: a.mean()) - expected = array.copy() - expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) - assert_identical(actual, expected) - with pytest.raises(ValueError, match=r"conflicting MultiIndex"): self.mda.assign_coords(level_1=range(4)) @@ -2614,719 +2588,6 @@ def test_fillna(self): with pytest.raises(ValueError, match=r"broadcast"): a.fillna([1, 2]) - fill_value = DataArray([0, 1], dims="y") - actual = a.fillna(fill_value) - expected = DataArray( - [[0, 1], [1, 1], [0, 1], [3, 3]], coords={"x": range(4)}, dims=("x", "y") - ) - assert_identical(expected, actual) - - expected = b.copy() - for target in [a, expected]: - target.coords["b"] = ("x", [0, 0, 1, 1]) - actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) - assert_identical(expected, actual) - - def test_groupby_iter(self): - for ((act_x, act_dv), (exp_x, exp_ds)) in zip( - self.dv.groupby("y"), self.ds.groupby("y") - ): - assert exp_x == act_x - assert_identical(exp_ds["foo"], act_dv) - for ((_, exp_dv), act_dv) in zip(self.dv.groupby("x"), self.dv): - assert_identical(exp_dv, act_dv) - - def make_groupby_example_array(self): - da = self.dv.copy() - da.coords["abc"] = ("y", np.array(["a"] * 9 + ["c"] + ["b"] * 10)) - da.coords["y"] = 20 + 100 * da["y"] - return da - - def test_groupby_properties(self): - grouped = self.make_groupby_example_array().groupby("abc") - expected_groups = {"a": range(0, 9), "c": [9], "b": range(10, 20)} - assert expected_groups.keys() == grouped.groups.keys() - for key in expected_groups: - assert_array_equal(expected_groups[key], grouped.groups[key]) - assert 3 == len(grouped) - - def test_groupby_map_identity(self): - expected = self.make_groupby_example_array() - idx = expected.coords["y"] - - def identity(x): - return x - - for g in ["x", "y", "abc", idx]: - for shortcut in [False, True]: - for squeeze in [False, True]: - grouped = expected.groupby(g, squeeze=squeeze) - actual = grouped.map(identity, shortcut=shortcut) - assert_identical(expected, actual) - - def test_groupby_sum(self): - array = self.make_groupby_example_array() - grouped = array.groupby("abc") - - expected_sum_all = Dataset( - { - "foo": Variable( - ["abc"], - np.array( - [ - self.x[:, :9].sum(), - self.x[:, 10:].sum(), - self.x[:, 9:10].sum(), - ] - ).T, - ), - "abc": Variable(["abc"], np.array(["a", "b", "c"])), - } - )["foo"] - assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=...)) - assert_allclose(expected_sum_all, grouped.sum(...)) - - expected = DataArray( - [ - array["y"].values[idx].sum() - for idx in [slice(9), slice(10, None), slice(9, 10)] - ], - [["a", "b", "c"]], - ["abc"], - ) - actual = array["y"].groupby("abc").map(np.sum) - assert_allclose(expected, actual) - actual = array["y"].groupby("abc").sum(...) 
- assert_allclose(expected, actual) - - expected_sum_axis1 = Dataset( - { - "foo": ( - ["x", "abc"], - np.array( - [ - self.x[:, :9].sum(1), - self.x[:, 10:].sum(1), - self.x[:, 9:10].sum(1), - ] - ).T, - ), - "abc": Variable(["abc"], np.array(["a", "b", "c"])), - } - )["foo"] - assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) - assert_allclose(expected_sum_axis1, grouped.sum("y")) - - def test_groupby_sum_default(self): - array = self.make_groupby_example_array() - grouped = array.groupby("abc") - - expected_sum_all = Dataset( - { - "foo": Variable( - ["x", "abc"], - np.array( - [ - self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1), - ] - ).T, - ), - "abc": Variable(["abc"], np.array(["a", "b", "c"])), - } - )["foo"] - - assert_allclose(expected_sum_all, grouped.sum(dim="y")) - - def test_groupby_count(self): - array = DataArray( - [0, 0, np.nan, np.nan, 0, 0], - coords={"cat": ("x", ["a", "b", "b", "c", "c", "c"])}, - dims="x", - ) - actual = array.groupby("cat").count() - expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])]) - assert_identical(actual, expected) - - @pytest.mark.skip("needs to be fixed for shortcut=False, keep_attrs=False") - def test_groupby_reduce_attrs(self): - array = self.make_groupby_example_array() - array.attrs["foo"] = "bar" - - for shortcut in [True, False]: - for keep_attrs in [True, False]: - print(f"shortcut={shortcut}, keep_attrs={keep_attrs}") - actual = array.groupby("abc").reduce( - np.mean, keep_attrs=keep_attrs, shortcut=shortcut - ) - expected = array.groupby("abc").mean() - if keep_attrs: - expected.attrs["foo"] = "bar" - assert_identical(expected, actual) - - def test_groupby_map_center(self): - def center(x): - return x - np.mean(x) - - array = self.make_groupby_example_array() - grouped = array.groupby("abc") - - expected_ds = array.to_dataset() - exp_data = np.hstack( - [center(self.x[:, :9]), center(self.x[:, 9:10]), center(self.x[:, 10:])] - ) - expected_ds["foo"] = (["x", "y"], exp_data) - expected_centered = expected_ds["foo"] - assert_allclose(expected_centered, grouped.map(center)) - - def test_groupby_map_ndarray(self): - # regression test for #326 - array = self.make_groupby_example_array() - grouped = array.groupby("abc") - actual = grouped.map(np.asarray) - assert_equal(array, actual) - - def test_groupby_map_changes_metadata(self): - def change_metadata(x): - x.coords["x"] = x.coords["x"] * 2 - x.attrs["fruit"] = "lemon" - return x - - array = self.make_groupby_example_array() - grouped = array.groupby("abc") - actual = grouped.map(change_metadata) - expected = array.copy() - expected = change_metadata(expected) - assert_equal(expected, actual) - - def test_groupby_math(self): - array = self.make_groupby_example_array() - for squeeze in [True, False]: - grouped = array.groupby("x", squeeze=squeeze) - - expected = array + array.coords["x"] - actual = grouped + array.coords["x"] - assert_identical(expected, actual) - - actual = array.coords["x"] + grouped - assert_identical(expected, actual) - - ds = array.coords["x"].to_dataset(name="X") - expected = array + ds - actual = grouped + ds - assert_identical(expected, actual) - - actual = ds + grouped - assert_identical(expected, actual) - - grouped = array.groupby("abc") - expected_agg = (grouped.mean(...) - np.arange(3)).rename(None) - actual = grouped - DataArray(range(3), [("abc", ["a", "b", "c"])]) - actual_agg = actual.groupby("abc").mean(...) 
- assert_allclose(expected_agg, actual_agg) - - with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + 1 - with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped - with pytest.raises(TypeError, match=r"in-place operations"): - array += grouped - - def test_groupby_math_not_aligned(self): - array = DataArray( - range(4), {"b": ("x", [0, 0, 1, 1]), "x": [0, 1, 2, 3]}, dims="x" - ) - other = DataArray([10], coords={"b": [0]}, dims="b") - actual = array.groupby("b") + other - expected = DataArray([10, 11, np.nan, np.nan], array.coords) - assert_identical(expected, actual) - - other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b") - actual = array.groupby("b") + other - expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) - assert_identical(expected, actual) - - other = Dataset({"a": ("b", [10])}, {"b": [0]}) - actual = array.groupby("b") + other - expected = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) - assert_identical(expected, actual) - - def test_groupby_restore_dim_order(self): - array = DataArray( - np.random.randn(5, 3), - coords={"a": ("x", range(5)), "b": ("y", range(3))}, - dims=["x", "y"], - ) - for by, expected_dims in [ - ("x", ("x", "y")), - ("y", ("x", "y")), - ("a", ("a", "y")), - ("b", ("x", "b")), - ]: - result = array.groupby(by).map(lambda x: x.squeeze()) - assert result.dims == expected_dims - - def test_groupby_restore_coord_dims(self): - array = DataArray( - np.random.randn(5, 3), - coords={ - "a": ("x", range(5)), - "b": ("y", range(3)), - "c": (("x", "y"), np.random.randn(5, 3)), - }, - dims=["x", "y"], - ) - - for by, expected_dims in [ - ("x", ("x", "y")), - ("y", ("x", "y")), - ("a", ("a", "y")), - ("b", ("x", "b")), - ]: - result = array.groupby(by, restore_coord_dims=True).map( - lambda x: x.squeeze() - )["c"] - assert result.dims == expected_dims - - def test_groupby_first_and_last(self): - array = DataArray([1, 2, 3, 4, 5], dims="x") - by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") - - expected = DataArray([1, 3], [("ab", ["a", "b"])]) - actual = array.groupby(by).first() - assert_identical(expected, actual) - - expected = DataArray([2, 5], [("ab", ["a", "b"])]) - actual = array.groupby(by).last() - assert_identical(expected, actual) - - array = DataArray(np.random.randn(5, 3), dims=["x", "y"]) - expected = DataArray(array[[0, 2]], {"ab": ["a", "b"]}, ["ab", "y"]) - actual = array.groupby(by).first() - assert_identical(expected, actual) - - actual = array.groupby("x").first() - expected = array # should be a no-op - assert_identical(expected, actual) - - def make_groupby_multidim_example_array(self): - return DataArray( - [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], - coords={ - "lon": (["ny", "nx"], [[30, 40], [40, 50]]), - "lat": (["ny", "nx"], [[10, 10], [20, 20]]), - }, - dims=["time", "ny", "nx"], - ) - - def test_groupby_multidim(self): - array = self.make_groupby_multidim_example_array() - for dim, expected_sum in [ - ("lon", DataArray([5, 28, 23], coords=[("lon", [30.0, 40.0, 50.0])])), - ("lat", DataArray([16, 40], coords=[("lat", [10.0, 20.0])])), - ]: - actual_sum = array.groupby(dim).sum(...) 
- assert_identical(expected_sum, actual_sum) - - def test_groupby_multidim_map(self): - array = self.make_groupby_multidim_example_array() - actual = array.groupby("lon").map(lambda x: x - x.mean()) - expected = DataArray( - [[[-2.5, -6.0], [-5.0, -8.5]], [[2.5, 3.0], [8.0, 8.5]]], - coords=array.coords, - dims=array.dims, - ) - assert_identical(expected, actual) - - def test_groupby_bins(self): - array = DataArray(np.arange(4), dims="dim_0") - # the first value should not be part of any group ("right" binning) - array[0] = 99 - # bins follow conventions for pandas.cut - # http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html - bins = [0, 1.5, 5] - bin_coords = pd.cut(array["dim_0"], bins).categories - expected = DataArray( - [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} - ) - # the problem with this is that it overwrites the dimensions of array! - # actual = array.groupby('dim_0', bins=bins).sum() - actual = array.groupby_bins("dim_0", bins).map(lambda x: x.sum()) - assert_identical(expected, actual) - # make sure original array dims are unchanged - assert len(array.dim_0) == 4 - - def test_groupby_bins_empty(self): - array = DataArray(np.arange(4), [("x", range(4))]) - # one of these bins will be empty - bins = [0, 4, 5] - bin_coords = pd.cut(array["x"], bins).categories - actual = array.groupby_bins("x", bins).sum() - expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) - assert_identical(expected, actual) - # make sure original array is unchanged - # (was a problem in earlier versions) - assert len(array.x) == 4 - - def test_groupby_bins_multidim(self): - array = self.make_groupby_multidim_example_array() - bins = [0, 15, 20] - bin_coords = pd.cut(array["lat"].values.flat, bins).categories - expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) - actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) - assert_identical(expected, actual) - # modify the array coordinates to be non-monotonic after unstacking - array["lat"].data = np.array([[10.0, 20.0], [20.0, 10.0]]) - expected = DataArray([28, 28], dims="lat_bins", coords={"lat_bins": bin_coords}) - actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) - assert_identical(expected, actual) - - def test_groupby_bins_sort(self): - data = xr.DataArray( - np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} - ) - binned_mean = data.groupby_bins("x", bins=11).mean() - assert binned_mean.to_index().is_monotonic - - def test_resample(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - array = DataArray(np.arange(10), [("time", times)]) - - actual = array.resample(time="24H").mean() - expected = DataArray(array.to_series().resample("24H").mean()) - assert_identical(expected, actual) - - actual = array.resample(time="24H").reduce(np.mean) - assert_identical(expected, actual) - - # Our use of `loffset` may change if we align our API with pandas' changes. 
- # ref https://github.com/pydata/xarray/pull/4537 - actual = array.resample(time="24H", loffset="-12H").mean() - expected_ = array.to_series().resample("24H").mean() - expected_.index += to_offset("-12H") - expected = DataArray.from_series(expected_) - assert_identical(actual, expected) - - with pytest.raises(ValueError, match=r"index must be monotonic"): - array[[2, 0, 1]].resample(time="1D") - - def test_da_resample_func_args(self): - def func(arg1, arg2, arg3=0.0): - return arg1.mean("time") + arg2 + arg3 - - times = pd.date_range("2000", periods=3, freq="D") - da = xr.DataArray([1.0, 1.0, 1.0], coords=[times], dims=["time"]) - expected = xr.DataArray([3.0, 3.0, 3.0], coords=[times], dims=["time"]) - actual = da.resample(time="D").map(func, args=(1.0,), arg3=1.0) - assert_identical(actual, expected) - - def test_resample_first(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - array = DataArray(np.arange(10), [("time", times)]) - - actual = array.resample(time="1D").first() - expected = DataArray([0, 4, 8], [("time", times[::4])]) - assert_identical(expected, actual) - - # verify that labels don't use the first value - actual = array.resample(time="24H").first() - expected = DataArray(array.to_series().resample("24H").first()) - assert_identical(expected, actual) - - # missing values - array = array.astype(float) - array[:2] = np.nan - actual = array.resample(time="1D").first() - expected = DataArray([2, 4, 8], [("time", times[::4])]) - assert_identical(expected, actual) - - actual = array.resample(time="1D").first(skipna=False) - expected = DataArray([np.nan, 4, 8], [("time", times[::4])]) - assert_identical(expected, actual) - - # regression test for http://stackoverflow.com/questions/33158558/ - array = Dataset({"time": times})["time"] - actual = array.resample(time="1D").last() - expected_times = pd.to_datetime( - ["2000-01-01T18", "2000-01-02T18", "2000-01-03T06"] - ) - expected = DataArray(expected_times, [("time", times[::4])], name="time") - assert_identical(expected, actual) - - def test_resample_bad_resample_dim(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - array = DataArray(np.arange(10), [("__resample_dim__", times)]) - with pytest.raises(ValueError, match=r"Proxy resampling dimension"): - array.resample(**{"__resample_dim__": "1D"}).first() - - @requires_scipy - def test_resample_drop_nondim_coords(self): - xs = np.arange(6) - ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) - data = np.tile(np.arange(5), (6, 3, 1)) - xx, yy = np.meshgrid(xs * 5, ys * 2.5) - tt = np.arange(len(times), dtype=int) - array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) - ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) - tcoord = DataArray(tt, {"time": times}, ("time",)) - ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) - ds = ds.set_coords(["xc", "yc", "tc"]) - - # Select the data now, with the auxiliary coordinates in place - array = ds["data"] - - # Re-sample - actual = array.resample(time="12H", restore_coord_dims=True).mean("time") - assert "tc" not in actual.coords - - # Up-sample - filling - actual = array.resample(time="1H", restore_coord_dims=True).ffill() - assert "tc" not in actual.coords - - # Up-sample - interpolation - actual = array.resample(time="1H", restore_coord_dims=True).interpolate( - "linear" - ) - assert "tc" not in actual.coords - - def test_resample_keep_attrs(self): - times = 
pd.date_range("2000-01-01", freq="6H", periods=10) - array = DataArray(np.ones(10), [("time", times)]) - array.attrs["meta"] = "data" - - result = array.resample(time="1D").mean(keep_attrs=True) - expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) - assert_identical(result, expected) - - with pytest.warns( - UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." - ): - array.resample(time="1D", keep_attrs=True) - - def test_resample_skipna(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - array = DataArray(np.ones(10), [("time", times)]) - array[1] = np.nan - - result = array.resample(time="1D").mean(skipna=False) - expected = DataArray([np.nan, 1, 1], [("time", times[::4])]) - assert_identical(result, expected) - - def test_upsample(self): - times = pd.date_range("2000-01-01", freq="6H", periods=5) - array = DataArray(np.arange(5), [("time", times)]) - - # Forward-fill - actual = array.resample(time="3H").ffill() - expected = DataArray(array.to_series().resample("3H").ffill()) - assert_identical(expected, actual) - - # Backward-fill - actual = array.resample(time="3H").bfill() - expected = DataArray(array.to_series().resample("3H").bfill()) - assert_identical(expected, actual) - - # As frequency - actual = array.resample(time="3H").asfreq() - expected = DataArray(array.to_series().resample("3H").asfreq()) - assert_identical(expected, actual) - - # Pad - actual = array.resample(time="3H").pad() - expected = DataArray(array.to_series().resample("3H").pad()) - assert_identical(expected, actual) - - # Nearest - rs = array.resample(time="3H") - actual = rs.nearest() - new_times = rs._full_index - expected = DataArray(array.reindex(time=new_times, method="nearest")) - assert_identical(expected, actual) - - def test_upsample_nd(self): - # Same as before, but now we try on multi-dimensional DataArrays. 
- xs = np.arange(6) - ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) - data = np.tile(np.arange(5), (6, 3, 1)) - array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - - # Forward-fill - actual = array.resample(time="3H").ffill() - expected_data = np.repeat(data, 2, axis=-1) - expected_times = times.to_series().resample("3H").asfreq().index - expected_data = expected_data[..., : len(expected_times)] - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - assert_identical(expected, actual) - - # Backward-fill - actual = array.resample(time="3H").ffill() - expected_data = np.repeat(np.flipud(data.T).T, 2, axis=-1) - expected_data = np.flipud(expected_data.T).T - expected_times = times.to_series().resample("3H").asfreq().index - expected_data = expected_data[..., : len(expected_times)] - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - assert_identical(expected, actual) - - # As frequency - actual = array.resample(time="3H").asfreq() - expected_data = np.repeat(data, 2, axis=-1).astype(float)[..., :-1] - expected_data[..., 1::2] = np.nan - expected_times = times.to_series().resample("3H").asfreq().index - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - assert_identical(expected, actual) - - # Pad - actual = array.resample(time="3H").pad() - expected_data = np.repeat(data, 2, axis=-1) - expected_data[..., 1::2] = expected_data[..., ::2] - expected_data = expected_data[..., :-1] - expected_times = times.to_series().resample("3H").asfreq().index - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - assert_identical(expected, actual) - - def test_upsample_tolerance(self): - # Test tolerance keyword for upsample methods bfill, pad, nearest - times = pd.date_range("2000-01-01", freq="1D", periods=2) - times_upsampled = pd.date_range("2000-01-01", freq="6H", periods=5) - array = DataArray(np.arange(2), [("time", times)]) - - # Forward fill - actual = array.resample(time="6H").ffill(tolerance="12H") - expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)]) - assert_identical(expected, actual) - - # Backward fill - actual = array.resample(time="6H").bfill(tolerance="12H") - expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)]) - assert_identical(expected, actual) - - # Nearest - actual = array.resample(time="6H").nearest(tolerance="6H") - expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)]) - assert_identical(expected, actual) - - @requires_scipy - def test_upsample_interpolate(self): - from scipy.interpolate import interp1d - - xs = np.arange(6) - ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) - - z = np.arange(5) ** 2 - data = np.tile(z, (6, 3, 1)) - array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - - expected_times = times.to_series().resample("1H").asfreq().index - # Split the times into equal sub-intervals to simulate the 6 hour - # to 1 hour up-sampling - new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.resample(time="1H").interpolate(kind) - f = interp1d( - np.arange(len(times)), - data, - kind=kind, - axis=-1, - bounds_error=True, - assume_sorted=True, - ) - 
expected_data = f(new_times_idx) - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - # Use AllClose because there are some small differences in how - # we upsample timeseries versus the integer indexing as I've - # done here due to floating point arithmetic - assert_allclose(expected, actual, rtol=1e-16) - - @requires_scipy - def test_upsample_interpolate_bug_2197(self): - dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") - da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) - result = da.resample(time="M").interpolate("linear") - expected_times = np.array( - [np.datetime64("2007-02-28"), np.datetime64("2007-03-31")] - ) - expected = xr.DataArray([27.0, np.nan], [("time", expected_times)]) - assert_equal(result, expected) - - @requires_scipy - def test_upsample_interpolate_regression_1605(self): - dates = pd.date_range("2016-01-01", "2016-03-31", freq="1D") - expected = xr.DataArray( - np.random.random((len(dates), 2, 3)), - dims=("time", "x", "y"), - coords={"time": dates}, - ) - actual = expected.resample(time="1D").interpolate("linear") - assert_allclose(actual, expected, rtol=1e-16) - - @requires_dask - @requires_scipy - @pytest.mark.parametrize("chunked_time", [True, False]) - def test_upsample_interpolate_dask(self, chunked_time): - from scipy.interpolate import interp1d - - xs = np.arange(6) - ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) - - z = np.arange(5) ** 2 - data = np.tile(z, (6, 3, 1)) - array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - chunks = {"x": 2, "y": 1} - if chunked_time: - chunks["time"] = 3 - - expected_times = times.to_series().resample("1H").asfreq().index - # Split the times into equal sub-intervals to simulate the 6 hour - # to 1 hour up-sampling - new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) - for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: - actual = array.chunk(chunks).resample(time="1H").interpolate(kind) - actual = actual.compute() - f = interp1d( - np.arange(len(times)), - data, - kind=kind, - axis=-1, - bounds_error=True, - assume_sorted=True, - ) - expected_data = f(new_times_idx) - expected = DataArray( - expected_data, - {"time": expected_times, "x": xs, "y": ys}, - ("x", "y", "time"), - ) - # Use AllClose because there are some small differences in how - # we upsample timeseries versus the integer indexing as I've - # done here due to floating point arithmetic - assert_allclose(expected, actual, rtol=1e-16) - def test_align(self): array = DataArray( np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1256a44ad81..7b17eae89c8 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -10,7 +10,6 @@ import pytest from pandas.core.computation.ops import UndefinedVariableError from pandas.core.indexes.datetimes import DatetimeIndex -from pandas.tseries.frequencies import to_offset import xarray as xr from xarray import ( @@ -3614,19 +3613,6 @@ def test_assign(self): expected = Dataset({"y": ("x", [0, 1, 4])}, {"z": 2, "x": [0, 1, 2]}) assert_identical(actual, expected) - ds = Dataset({"a": ("x", range(3))}, {"b": ("x", ["A"] * 2 + ["B"])}) - actual = ds.groupby("b").assign(c=lambda ds: 2 * ds.a) - expected = ds.merge({"c": ("x", [0, 2, 4])}) - assert_identical(actual, expected) - - actual = ds.groupby("b").assign(c=lambda ds: ds.a.sum()) - expected 
= ds.merge({"c": ("x", [1, 1, 2])}) - assert_identical(actual, expected) - - actual = ds.groupby("b").assign_coords(c=lambda ds: ds.a.sum()) - expected = expected.set_coords("c") - assert_identical(actual, expected) - def test_assign_coords(self): ds = Dataset() @@ -3758,200 +3744,6 @@ def test_squeeze_drop(self): selected = data.squeeze(drop=True) assert_identical(data, selected) - def test_resample_and_first(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - - actual = ds.resample(time="1D").first(keep_attrs=True) - expected = ds.isel(time=[0, 4, 8]) - assert_identical(expected, actual) - - # upsampling - expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) - expected = ds.reindex(time=expected_time) - actual = ds.resample(time="3H") - for how in ["mean", "sum", "first", "last"]: - method = getattr(actual, how) - result = method() - assert_equal(expected, result) - for method in [np.mean]: - result = actual.reduce(method) - assert_equal(expected, result) - - def test_resample_min_count(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - # inject nan - ds["foo"] = xr.where(ds["foo"] > 2.0, np.nan, ds["foo"]) - - actual = ds.resample(time="1D").sum(min_count=1) - expected = xr.concat( - [ - ds.isel(time=slice(i * 4, (i + 1) * 4)).sum("time", min_count=1) - for i in range(3) - ], - dim=actual["time"], - ) - assert_equal(expected, actual) - - def test_resample_by_mean_with_keep_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - ds.attrs["dsmeta"] = "dsdata" - - resampled_ds = ds.resample(time="1D").mean(keep_attrs=True) - actual = resampled_ds["bar"].attrs - expected = ds["bar"].attrs - assert expected == actual - - actual = resampled_ds.attrs - expected = ds.attrs - assert expected == actual - - with pytest.warns( - UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." - ): - ds.resample(time="1D", keep_attrs=True) - - def test_resample_loffset(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - ds.attrs["dsmeta"] = "dsdata" - - # Our use of `loffset` may change if we align our API with pandas' changes. 
- # ref https://github.com/pydata/xarray/pull/4537 - actual = ds.resample(time="24H", loffset="-12H").mean().bar - expected_ = ds.bar.to_series().resample("24H").mean() - expected_.index += to_offset("-12H") - expected = DataArray.from_series(expected_) - assert_allclose(actual, expected) - - def test_resample_by_mean_discarding_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - ds.attrs["dsmeta"] = "dsdata" - - resampled_ds = ds.resample(time="1D").mean(keep_attrs=False) - - assert resampled_ds["bar"].attrs == {} - assert resampled_ds.attrs == {} - - def test_resample_by_last_discarding_attrs(self): - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - ds.attrs["dsmeta"] = "dsdata" - - resampled_ds = ds.resample(time="1D").last(keep_attrs=False) - - assert resampled_ds["bar"].attrs == {} - assert resampled_ds.attrs == {} - - @requires_scipy - def test_resample_drop_nondim_coords(self): - xs = np.arange(6) - ys = np.arange(3) - times = pd.date_range("2000-01-01", freq="6H", periods=5) - data = np.tile(np.arange(5), (6, 3, 1)) - xx, yy = np.meshgrid(xs * 5, ys * 2.5) - tt = np.arange(len(times), dtype=int) - array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) - xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) - ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) - tcoord = DataArray(tt, {"time": times}, ("time",)) - ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) - ds = ds.set_coords(["xc", "yc", "tc"]) - - # Re-sample - actual = ds.resample(time="12H").mean("time") - assert "tc" not in actual.coords - - # Up-sample - filling - actual = ds.resample(time="1H").ffill() - assert "tc" not in actual.coords - - # Up-sample - interpolation - actual = ds.resample(time="1H").interpolate("linear") - assert "tc" not in actual.coords - - def test_resample_old_api(self): - - times = pd.date_range("2000-01-01", freq="6H", periods=10) - ds = Dataset( - { - "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), - "bar": ("time", np.random.randn(10), {"meta": "data"}), - "time": times, - } - ) - - with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", "time") - - with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", dim="time", how="mean") - - with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): - ds.resample("1D", dim="time") - - def test_resample_ds_da_are_the_same(self): - time = pd.date_range("2000-01-01", freq="6H", periods=365 * 4) - ds = xr.Dataset( - { - "foo": (("time", "x"), np.random.randn(365 * 4, 5)), - "time": time, - "x": np.arange(5), - } - ) - assert_identical( - ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() - ) - - def test_ds_resample_apply_func_args(self): - def func(arg1, arg2, arg3=0.0): - return arg1.mean("time") + arg2 + arg3 - - times = pd.date_range("2000", freq="D", periods=3) - ds = xr.Dataset({"foo": ("time", [1.0, 1.0, 1.0]), "time": times}) - expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) - actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) - assert_identical(expected, actual) - def test_to_array(self): 
ds = Dataset( {"a": 1, "b": ("x", [1, 2, 3])}, @@ -4451,24 +4243,6 @@ def test_fillna(self): expected = ds.assign_coords(c=42) assert_identical(expected, result) - # groupby - expected = Dataset({"a": ("x", range(4))}, {"x": [0, 1, 2, 3]}) - for target in [ds, expected]: - target.coords["b"] = ("x", [0, 0, 1, 1]) - actual = ds.groupby("b").fillna(DataArray([0, 2], dims="b")) - assert_identical(expected, actual) - - actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) - assert_identical(expected, actual) - - # attrs with groupby - ds.attrs["attr"] = "ds" - ds.a.attrs["attr"] = "da" - actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) - assert actual.attrs == ds.attrs - assert actual.a.name == "a" - assert actual.a.attrs == ds.a.attrs - da = DataArray(range(5), name="a", attrs={"attr": "da"}) actual = da.fillna(1) assert actual.name == "a" @@ -4528,22 +4302,6 @@ def test_where(self): actual = ds.where(ds > 0) assert_identical(expected, actual) - # groupby - ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) - cond = Dataset({"a": ("c", [True, False])}) - expected = ds.copy(deep=True) - expected["a"].values = [0, 1] + [np.nan] * 3 - actual = ds.groupby("c").where(cond) - assert_identical(expected, actual) - - # attrs with groupby - ds.attrs["attr"] = "ds" - ds.a.attrs["attr"] = "da" - actual = ds.groupby("c").where(cond) - assert actual.attrs == ds.attrs - assert actual.a.name == "a" - assert actual.a.attrs == ds.a.attrs - # attrs da = DataArray(range(5), name="a", attrs={"attr": "da"}) actual = da.where(da.values > 1) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b2510141d78..ee77865dd24 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1,12 +1,21 @@ import numpy as np import pandas as pd import pytest +from pandas.tseries.frequencies import to_offset import xarray as xr from xarray import DataArray, Dataset, Variable from xarray.core.groupby import _consolidate_slices -from . import assert_allclose, assert_equal, assert_identical, create_test_data +from . 
import ( + assert_allclose, + assert_array_equal, + assert_equal, + assert_identical, + create_test_data, + requires_dask, + requires_scipy, +) @pytest.fixture @@ -741,4 +750,1013 @@ def test_groupby_dataset_order(): # .assertEqual(all_vars, all_vars_ref) +def test_groupby_dataset_fillna(): + + ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) + expected = Dataset({"a": ("x", range(4))}, {"x": [0, 1, 2, 3]}) + for target in [ds, expected]: + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = ds.groupby("b").fillna(DataArray([0, 2], dims="b")) + assert_identical(expected, actual) + + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) + assert_identical(expected, actual) + + # attrs with groupby + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("b").fillna(Dataset({"a": ("b", [0, 2])})) + assert actual.attrs == ds.attrs + assert actual.a.name == "a" + assert actual.a.attrs == ds.a.attrs + + +def test_groupby_dataset_where(): + # groupby + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + cond = Dataset({"a": ("c", [True, False])}) + expected = ds.copy(deep=True) + expected["a"].values = [0, 1] + [np.nan] * 3 + actual = ds.groupby("c").where(cond) + assert_identical(expected, actual) + + # attrs with groupby + ds.attrs["attr"] = "ds" + ds.a.attrs["attr"] = "da" + actual = ds.groupby("c").where(cond) + assert actual.attrs == ds.attrs + assert actual.a.name == "a" + assert actual.a.attrs == ds.a.attrs + + +def test_groupby_dataset_assign(): + ds = Dataset({"a": ("x", range(3))}, {"b": ("x", ["A"] * 2 + ["B"])}) + actual = ds.groupby("b").assign(c=lambda ds: 2 * ds.a) + expected = ds.merge({"c": ("x", [0, 2, 4])}) + assert_identical(actual, expected) + + actual = ds.groupby("b").assign(c=lambda ds: ds.a.sum()) + expected = ds.merge({"c": ("x", [1, 1, 2])}) + assert_identical(actual, expected) + + actual = ds.groupby("b").assign_coords(c=lambda ds: ds.a.sum()) + expected = expected.set_coords("c") + assert_identical(actual, expected) + + +class TestDataArrayGroupBy: + @pytest.fixture(autouse=True) + def setup(self): + self.attrs = {"attr1": "value1", "attr2": 2929} + self.x = np.random.random((10, 20)) + self.v = Variable(["x", "y"], self.x) + self.va = Variable(["x", "y"], self.x, self.attrs) + self.ds = Dataset({"foo": self.v}) + self.dv = self.ds["foo"] + + self.mindex = pd.MultiIndex.from_product( + [["a", "b"], [1, 2]], names=("level_1", "level_2") + ) + self.mda = DataArray([0, 1, 2, 3], coords={"x": self.mindex}, dims="x") + + self.da = self.dv.copy() + self.da.coords["abc"] = ("y", np.array(["a"] * 9 + ["c"] + ["b"] * 10)) + self.da.coords["y"] = 20 + 100 * self.da["y"] + + def test_stack_groupby_unsorted_coord(self): + data = [[0, 1], [2, 3]] + data_flat = [0, 1, 2, 3] + dims = ["x", "y"] + y_vals = [2, 3] + + arr = xr.DataArray(data, dims=dims, coords={"y": y_vals}) + actual1 = arr.stack(z=dims).groupby("z").first() + midx1 = pd.MultiIndex.from_product([[0, 1], [2, 3]], names=dims) + expected1 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx1}) + assert_equal(actual1, expected1) + + # GH: 3287. Note that y coord values are not in sorted order. 
+ arr = xr.DataArray(data, dims=dims, coords={"y": y_vals[::-1]}) + actual2 = arr.stack(z=dims).groupby("z").first() + midx2 = pd.MultiIndex.from_product([[0, 1], [3, 2]], names=dims) + expected2 = xr.DataArray(data_flat, dims=["z"], coords={"z": midx2}) + assert_equal(actual2, expected2) + + def test_groupby_iter(self): + for ((act_x, act_dv), (exp_x, exp_ds)) in zip( + self.dv.groupby("y"), self.ds.groupby("y") + ): + assert exp_x == act_x + assert_identical(exp_ds["foo"], act_dv) + for ((_, exp_dv), act_dv) in zip(self.dv.groupby("x"), self.dv): + assert_identical(exp_dv, act_dv) + + def test_groupby_properties(self): + grouped = self.da.groupby("abc") + expected_groups = {"a": range(0, 9), "c": [9], "b": range(10, 20)} + assert expected_groups.keys() == grouped.groups.keys() + for key in expected_groups: + assert_array_equal(expected_groups[key], grouped.groups[key]) + assert 3 == len(grouped) + + def test_groupby_map_identity(self): + expected = self.da + idx = expected.coords["y"] + + def identity(x): + return x + + for g in ["x", "y", "abc", idx]: + for shortcut in [False, True]: + for squeeze in [False, True]: + grouped = expected.groupby(g, squeeze=squeeze) + actual = grouped.map(identity, shortcut=shortcut) + assert_identical(expected, actual) + + def test_groupby_sum(self): + array = self.da + grouped = array.groupby("abc") + + expected_sum_all = Dataset( + { + "foo": Variable( + ["abc"], + np.array( + [ + self.x[:, :9].sum(), + self.x[:, 10:].sum(), + self.x[:, 9:10].sum(), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + assert_allclose(expected_sum_all, grouped.reduce(np.sum, dim=...)) + assert_allclose(expected_sum_all, grouped.sum(...)) + + expected = DataArray( + [ + array["y"].values[idx].sum() + for idx in [slice(9), slice(10, None), slice(9, 10)] + ], + [["a", "b", "c"]], + ["abc"], + ) + actual = array["y"].groupby("abc").map(np.sum) + assert_allclose(expected, actual) + actual = array["y"].groupby("abc").sum(...) 
+ assert_allclose(expected, actual) + + expected_sum_axis1 = Dataset( + { + "foo": ( + ["x", "abc"], + np.array( + [ + self.x[:, :9].sum(1), + self.x[:, 10:].sum(1), + self.x[:, 9:10].sum(1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) + assert_allclose(expected_sum_axis1, grouped.sum("y")) + + def test_groupby_sum_default(self): + array = self.da + grouped = array.groupby("abc") + + expected_sum_all = Dataset( + { + "foo": Variable( + ["x", "abc"], + np.array( + [ + self.x[:, :9].sum(axis=-1), + self.x[:, 10:].sum(axis=-1), + self.x[:, 9:10].sum(axis=-1), + ] + ).T, + ), + "abc": Variable(["abc"], np.array(["a", "b", "c"])), + } + )["foo"] + + assert_allclose(expected_sum_all, grouped.sum(dim="y")) + + def test_groupby_count(self): + array = DataArray( + [0, 0, np.nan, np.nan, 0, 0], + coords={"cat": ("x", ["a", "b", "b", "c", "c", "c"])}, + dims="x", + ) + actual = array.groupby("cat").count() + expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])]) + assert_identical(actual, expected) + + @pytest.mark.skip("needs to be fixed for shortcut=False, keep_attrs=False") + def test_groupby_reduce_attrs(self): + array = self.da + array.attrs["foo"] = "bar" + + for shortcut in [True, False]: + for keep_attrs in [True, False]: + print(f"shortcut={shortcut}, keep_attrs={keep_attrs}") + actual = array.groupby("abc").reduce( + np.mean, keep_attrs=keep_attrs, shortcut=shortcut + ) + expected = array.groupby("abc").mean() + if keep_attrs: + expected.attrs["foo"] = "bar" + assert_identical(expected, actual) + + def test_groupby_map_center(self): + def center(x): + return x - np.mean(x) + + array = self.da + grouped = array.groupby("abc") + + expected_ds = array.to_dataset() + exp_data = np.hstack( + [center(self.x[:, :9]), center(self.x[:, 9:10]), center(self.x[:, 10:])] + ) + expected_ds["foo"] = (["x", "y"], exp_data) + expected_centered = expected_ds["foo"] + assert_allclose(expected_centered, grouped.map(center)) + + def test_groupby_map_ndarray(self): + # regression test for #326 + array = self.da + grouped = array.groupby("abc") + actual = grouped.map(np.asarray) + assert_equal(array, actual) + + def test_groupby_map_changes_metadata(self): + def change_metadata(x): + x.coords["x"] = x.coords["x"] * 2 + x.attrs["fruit"] = "lemon" + return x + + array = self.da + grouped = array.groupby("abc") + actual = grouped.map(change_metadata) + expected = array.copy() + expected = change_metadata(expected) + assert_equal(expected, actual) + + def test_groupby_math(self): + array = self.da + for squeeze in [True, False]: + grouped = array.groupby("x", squeeze=squeeze) + + expected = array + array.coords["x"] + actual = grouped + array.coords["x"] + assert_identical(expected, actual) + + actual = array.coords["x"] + grouped + assert_identical(expected, actual) + + ds = array.coords["x"].to_dataset(name="X") + expected = array + ds + actual = grouped + ds + assert_identical(expected, actual) + + actual = ds + grouped + assert_identical(expected, actual) + + grouped = array.groupby("abc") + expected_agg = (grouped.mean(...) - np.arange(3)).rename(None) + actual = grouped - DataArray(range(3), [("abc", ["a", "b", "c"])]) + actual_agg = actual.groupby("abc").mean(...) 
+ assert_allclose(expected_agg, actual_agg) + + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + 1 + with pytest.raises(TypeError, match=r"only support binary ops"): + grouped + grouped + with pytest.raises(TypeError, match=r"in-place operations"): + array += grouped + + def test_groupby_math_not_aligned(self): + array = DataArray( + range(4), {"b": ("x", [0, 0, 1, 1]), "x": [0, 1, 2, 3]}, dims="x" + ) + other = DataArray([10], coords={"b": [0]}, dims="b") + actual = array.groupby("b") + other + expected = DataArray([10, 11, np.nan, np.nan], array.coords) + assert_identical(expected, actual) + + other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b") + actual = array.groupby("b") + other + expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2) + assert_identical(expected, actual) + + other = Dataset({"a": ("b", [10])}, {"b": [0]}) + actual = array.groupby("b") + other + expected = Dataset({"a": ("x", [10, 11, np.nan, np.nan])}, array.coords) + assert_identical(expected, actual) + + def test_groupby_restore_dim_order(self): + array = DataArray( + np.random.randn(5, 3), + coords={"a": ("x", range(5)), "b": ("y", range(3))}, + dims=["x", "y"], + ) + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: + result = array.groupby(by).map(lambda x: x.squeeze()) + assert result.dims == expected_dims + + def test_groupby_restore_coord_dims(self): + array = DataArray( + np.random.randn(5, 3), + coords={ + "a": ("x", range(5)), + "b": ("y", range(3)), + "c": (("x", "y"), np.random.randn(5, 3)), + }, + dims=["x", "y"], + ) + + for by, expected_dims in [ + ("x", ("x", "y")), + ("y", ("x", "y")), + ("a", ("a", "y")), + ("b", ("x", "b")), + ]: + result = array.groupby(by, restore_coord_dims=True).map( + lambda x: x.squeeze() + )["c"] + assert result.dims == expected_dims + + def test_groupby_first_and_last(self): + array = DataArray([1, 2, 3, 4, 5], dims="x") + by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") + + expected = DataArray([1, 3], [("ab", ["a", "b"])]) + actual = array.groupby(by).first() + assert_identical(expected, actual) + + expected = DataArray([2, 5], [("ab", ["a", "b"])]) + actual = array.groupby(by).last() + assert_identical(expected, actual) + + array = DataArray(np.random.randn(5, 3), dims=["x", "y"]) + expected = DataArray(array[[0, 2]], {"ab": ["a", "b"]}, ["ab", "y"]) + actual = array.groupby(by).first() + assert_identical(expected, actual) + + actual = array.groupby("x").first() + expected = array # should be a no-op + assert_identical(expected, actual) + + def make_groupby_multidim_example_array(self): + return DataArray( + [[[0, 1], [2, 3]], [[5, 10], [15, 20]]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["time", "ny", "nx"], + ) + + def test_groupby_multidim(self): + array = self.make_groupby_multidim_example_array() + for dim, expected_sum in [ + ("lon", DataArray([5, 28, 23], coords=[("lon", [30.0, 40.0, 50.0])])), + ("lat", DataArray([16, 40], coords=[("lat", [10.0, 20.0])])), + ]: + actual_sum = array.groupby(dim).sum(...) 
+ assert_identical(expected_sum, actual_sum) + + def test_groupby_multidim_map(self): + array = self.make_groupby_multidim_example_array() + actual = array.groupby("lon").map(lambda x: x - x.mean()) + expected = DataArray( + [[[-2.5, -6.0], [-5.0, -8.5]], [[2.5, 3.0], [8.0, 8.5]]], + coords=array.coords, + dims=array.dims, + ) + assert_identical(expected, actual) + + def test_groupby_bins(self): + array = DataArray(np.arange(4), dims="dim_0") + # the first value should not be part of any group ("right" binning) + array[0] = 99 + # bins follow conventions for pandas.cut + # http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + bins = [0, 1.5, 5] + bin_coords = pd.cut(array["dim_0"], bins).categories + expected = DataArray( + [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} + ) + # the problem with this is that it overwrites the dimensions of array! + # actual = array.groupby('dim_0', bins=bins).sum() + actual = array.groupby_bins("dim_0", bins).map(lambda x: x.sum()) + assert_identical(expected, actual) + # make sure original array dims are unchanged + assert len(array.dim_0) == 4 + + def test_groupby_bins_empty(self): + array = DataArray(np.arange(4), [("x", range(4))]) + # one of these bins will be empty + bins = [0, 4, 5] + bin_coords = pd.cut(array["x"], bins).categories + actual = array.groupby_bins("x", bins).sum() + expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) + assert_identical(expected, actual) + # make sure original array is unchanged + # (was a problem in earlier versions) + assert len(array.x) == 4 + + def test_groupby_bins_multidim(self): + array = self.make_groupby_multidim_example_array() + bins = [0, 15, 20] + bin_coords = pd.cut(array["lat"].values.flat, bins).categories + expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) + assert_identical(expected, actual) + # modify the array coordinates to be non-monotonic after unstacking + array["lat"].data = np.array([[10.0, 20.0], [20.0, 10.0]]) + expected = DataArray([28, 28], dims="lat_bins", coords={"lat_bins": bin_coords}) + actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) + assert_identical(expected, actual) + + def test_groupby_bins_sort(self): + data = xr.DataArray( + np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} + ) + binned_mean = data.groupby_bins("x", bins=11).mean() + assert binned_mean.to_index().is_monotonic + + def test_groupby_assign_coords(self): + + array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") + actual = array.groupby("c").assign_coords(d=lambda a: a.mean()) + expected = array.copy() + expected.coords["d"] = ("x", [1.5, 1.5, 3.5, 3.5]) + assert_identical(actual, expected) + + def test_groupby_fillna(self): + a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") + fill_value = DataArray([0, 1], dims="y") + actual = a.fillna(fill_value) + expected = DataArray( + [[0, 1], [1, 1], [0, 1], [3, 3]], coords={"x": range(4)}, dims=("x", "y") + ) + assert_identical(expected, actual) + + b = DataArray(range(4), coords={"x": range(4)}, dims="x") + expected = b.copy() + for target in [a, expected]: + target.coords["b"] = ("x", [0, 0, 1, 1]) + actual = a.groupby("b").fillna(DataArray([0, 2], dims="b")) + assert_identical(expected, actual) + + +class TestDataArrayResample: + def test_resample(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = 
DataArray(np.arange(10), [("time", times)]) + + actual = array.resample(time="24H").mean() + expected = DataArray(array.to_series().resample("24H").mean()) + assert_identical(expected, actual) + + actual = array.resample(time="24H").reduce(np.mean) + assert_identical(expected, actual) + + # Our use of `loffset` may change if we align our API with pandas' changes. + # ref https://github.com/pydata/xarray/pull/4537 + actual = array.resample(time="24H", loffset="-12H").mean() + expected_ = array.to_series().resample("24H").mean() + expected_.index += to_offset("-12H") + expected = DataArray.from_series(expected_) + assert_identical(actual, expected) + + with pytest.raises(ValueError, match=r"index must be monotonic"): + array[[2, 0, 1]].resample(time="1D") + + def test_da_resample_func_args(self): + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 + + times = pd.date_range("2000", periods=3, freq="D") + da = xr.DataArray([1.0, 1.0, 1.0], coords=[times], dims=["time"]) + expected = xr.DataArray([3.0, 3.0, 3.0], coords=[times], dims=["time"]) + actual = da.resample(time="D").map(func, args=(1.0,), arg3=1.0) + assert_identical(actual, expected) + + def test_resample_first(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.arange(10), [("time", times)]) + + actual = array.resample(time="1D").first() + expected = DataArray([0, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + # verify that labels don't use the first value + actual = array.resample(time="24H").first() + expected = DataArray(array.to_series().resample("24H").first()) + assert_identical(expected, actual) + + # missing values + array = array.astype(float) + array[:2] = np.nan + actual = array.resample(time="1D").first() + expected = DataArray([2, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + actual = array.resample(time="1D").first(skipna=False) + expected = DataArray([np.nan, 4, 8], [("time", times[::4])]) + assert_identical(expected, actual) + + # regression test for http://stackoverflow.com/questions/33158558/ + array = Dataset({"time": times})["time"] + actual = array.resample(time="1D").last() + expected_times = pd.to_datetime( + ["2000-01-01T18", "2000-01-02T18", "2000-01-03T06"] + ) + expected = DataArray(expected_times, [("time", times[::4])], name="time") + assert_identical(expected, actual) + + def test_resample_bad_resample_dim(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.arange(10), [("__resample_dim__", times)]) + with pytest.raises(ValueError, match=r"Proxy resampling dimension"): + array.resample(**{"__resample_dim__": "1D"}).first() + + @requires_scipy + def test_resample_drop_nondim_coords(self): + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6H", periods=5) + data = np.tile(np.arange(5), (6, 3, 1)) + xx, yy = np.meshgrid(xs * 5, ys * 2.5) + tt = np.arange(len(times), dtype=int) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) + + # Select the data now, with the auxiliary coordinates in place + array = ds["data"] + + # Re-sample + actual = array.resample(time="12H", restore_coord_dims=True).mean("time") + 
assert "tc" not in actual.coords + + # Up-sample - filling + actual = array.resample(time="1H", restore_coord_dims=True).ffill() + assert "tc" not in actual.coords + + # Up-sample - interpolation + actual = array.resample(time="1H", restore_coord_dims=True).interpolate( + "linear" + ) + assert "tc" not in actual.coords + + def test_resample_keep_attrs(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.ones(10), [("time", times)]) + array.attrs["meta"] = "data" + + result = array.resample(time="1D").mean(keep_attrs=True) + expected = DataArray([1, 1, 1], [("time", times[::4])], attrs=array.attrs) + assert_identical(result, expected) + + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." + ): + array.resample(time="1D", keep_attrs=True) + + def test_resample_skipna(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + array = DataArray(np.ones(10), [("time", times)]) + array[1] = np.nan + + result = array.resample(time="1D").mean(skipna=False) + expected = DataArray([np.nan, 1, 1], [("time", times[::4])]) + assert_identical(result, expected) + + def test_upsample(self): + times = pd.date_range("2000-01-01", freq="6H", periods=5) + array = DataArray(np.arange(5), [("time", times)]) + + # Forward-fill + actual = array.resample(time="3H").ffill() + expected = DataArray(array.to_series().resample("3H").ffill()) + assert_identical(expected, actual) + + # Backward-fill + actual = array.resample(time="3H").bfill() + expected = DataArray(array.to_series().resample("3H").bfill()) + assert_identical(expected, actual) + + # As frequency + actual = array.resample(time="3H").asfreq() + expected = DataArray(array.to_series().resample("3H").asfreq()) + assert_identical(expected, actual) + + # Pad + actual = array.resample(time="3H").pad() + expected = DataArray(array.to_series().resample("3H").pad()) + assert_identical(expected, actual) + + # Nearest + rs = array.resample(time="3H") + actual = rs.nearest() + new_times = rs._full_index + expected = DataArray(array.reindex(time=new_times, method="nearest")) + assert_identical(expected, actual) + + def test_upsample_nd(self): + # Same as before, but now we try on multi-dimensional DataArrays. 
+        xs = np.arange(6)
+        ys = np.arange(3)
+        times = pd.date_range("2000-01-01", freq="6H", periods=5)
+        data = np.tile(np.arange(5), (6, 3, 1))
+        array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time"))
+
+        # Forward-fill
+        actual = array.resample(time="3H").ffill()
+        expected_data = np.repeat(data, 2, axis=-1)
+        expected_times = times.to_series().resample("3H").asfreq().index
+        expected_data = expected_data[..., : len(expected_times)]
+        expected = DataArray(
+            expected_data,
+            {"time": expected_times, "x": xs, "y": ys},
+            ("x", "y", "time"),
+        )
+        assert_identical(expected, actual)
+
+        # Backward-fill
+        actual = array.resample(time="3H").bfill()
+        expected_data = np.repeat(np.flipud(data.T).T, 2, axis=-1)
+        expected_data = np.flipud(expected_data.T).T
+        expected_times = times.to_series().resample("3H").asfreq().index
+        expected_data = expected_data[..., -len(expected_times) :]
+        expected = DataArray(
+            expected_data,
+            {"time": expected_times, "x": xs, "y": ys},
+            ("x", "y", "time"),
+        )
+        assert_identical(expected, actual)
+
+        # As frequency
+        actual = array.resample(time="3H").asfreq()
+        expected_data = np.repeat(data, 2, axis=-1).astype(float)[..., :-1]
+        expected_data[..., 1::2] = np.nan
+        expected_times = times.to_series().resample("3H").asfreq().index
+        expected = DataArray(
+            expected_data,
+            {"time": expected_times, "x": xs, "y": ys},
+            ("x", "y", "time"),
+        )
+        assert_identical(expected, actual)
+
+        # Pad
+        actual = array.resample(time="3H").pad()
+        expected_data = np.repeat(data, 2, axis=-1)
+        expected_data[..., 1::2] = expected_data[..., ::2]
+        expected_data = expected_data[..., :-1]
+        expected_times = times.to_series().resample("3H").asfreq().index
+        expected = DataArray(
+            expected_data,
+            {"time": expected_times, "x": xs, "y": ys},
+            ("x", "y", "time"),
+        )
+        assert_identical(expected, actual)
+
+    def test_upsample_tolerance(self):
+        # Test tolerance keyword for upsample methods bfill, pad, nearest
+        times = pd.date_range("2000-01-01", freq="1D", periods=2)
+        times_upsampled = pd.date_range("2000-01-01", freq="6H", periods=5)
+        array = DataArray(np.arange(2), [("time", times)])
+
+        # Forward fill
+        actual = array.resample(time="6H").ffill(tolerance="12H")
+        expected = DataArray([0.0, 0.0, 0.0, np.nan, 1.0], [("time", times_upsampled)])
+        assert_identical(expected, actual)
+
+        # Backward fill
+        actual = array.resample(time="6H").bfill(tolerance="12H")
+        expected = DataArray([0.0, np.nan, 1.0, 1.0, 1.0], [("time", times_upsampled)])
+        assert_identical(expected, actual)
+
+        # Nearest
+        actual = array.resample(time="6H").nearest(tolerance="6H")
+        expected = DataArray([0, 0, np.nan, 1, 1], [("time", times_upsampled)])
+        assert_identical(expected, actual)
+
+    @requires_scipy
+    def test_upsample_interpolate(self):
+        from scipy.interpolate import interp1d
+
+        xs = np.arange(6)
+        ys = np.arange(3)
+        times = pd.date_range("2000-01-01", freq="6H", periods=5)
+
+        z = np.arange(5) ** 2
+        data = np.tile(z, (6, 3, 1))
+        array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time"))
+
+        expected_times = times.to_series().resample("1H").asfreq().index
+        # Split the times into equal sub-intervals to simulate the 6 hour
+        # to 1 hour up-sampling
+        new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5)
+        for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]:
+            actual = array.resample(time="1H").interpolate(kind)
+            f = interp1d(
+                np.arange(len(times)),
+                data,
+                kind=kind,
+                axis=-1,
+                bounds_error=True,
+                assume_sorted=True,
+            )
+
expected_data = f(new_times_idx) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + # Use AllClose because there are some small differences in how + # we upsample timeseries versus the integer indexing as I've + # done here due to floating point arithmetic + assert_allclose(expected, actual, rtol=1e-16) + + @requires_scipy + def test_upsample_interpolate_bug_2197(self): + dates = pd.date_range("2007-02-01", "2007-03-01", freq="D") + da = xr.DataArray(np.arange(len(dates)), [("time", dates)]) + result = da.resample(time="M").interpolate("linear") + expected_times = np.array( + [np.datetime64("2007-02-28"), np.datetime64("2007-03-31")] + ) + expected = xr.DataArray([27.0, np.nan], [("time", expected_times)]) + assert_equal(result, expected) + + @requires_scipy + def test_upsample_interpolate_regression_1605(self): + dates = pd.date_range("2016-01-01", "2016-03-31", freq="1D") + expected = xr.DataArray( + np.random.random((len(dates), 2, 3)), + dims=("time", "x", "y"), + coords={"time": dates}, + ) + actual = expected.resample(time="1D").interpolate("linear") + assert_allclose(actual, expected, rtol=1e-16) + + @requires_dask + @requires_scipy + @pytest.mark.parametrize("chunked_time", [True, False]) + def test_upsample_interpolate_dask(self, chunked_time): + from scipy.interpolate import interp1d + + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6H", periods=5) + + z = np.arange(5) ** 2 + data = np.tile(z, (6, 3, 1)) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + chunks = {"x": 2, "y": 1} + if chunked_time: + chunks["time"] = 3 + + expected_times = times.to_series().resample("1H").asfreq().index + # Split the times into equal sub-intervals to simulate the 6 hour + # to 1 hour up-sampling + new_times_idx = np.linspace(0, len(times) - 1, len(times) * 5) + for kind in ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]: + actual = array.chunk(chunks).resample(time="1H").interpolate(kind) + actual = actual.compute() + f = interp1d( + np.arange(len(times)), + data, + kind=kind, + axis=-1, + bounds_error=True, + assume_sorted=True, + ) + expected_data = f(new_times_idx) + expected = DataArray( + expected_data, + {"time": expected_times, "x": xs, "y": ys}, + ("x", "y", "time"), + ) + # Use AllClose because there are some small differences in how + # we upsample timeseries versus the integer indexing as I've + # done here due to floating point arithmetic + assert_allclose(expected, actual, rtol=1e-16) + + +class TestDatasetResample: + def test_resample_and_first(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + + actual = ds.resample(time="1D").first(keep_attrs=True) + expected = ds.isel(time=[0, 4, 8]) + assert_identical(expected, actual) + + # upsampling + expected_time = pd.date_range("2000-01-01", freq="3H", periods=19) + expected = ds.reindex(time=expected_time) + actual = ds.resample(time="3H") + for how in ["mean", "sum", "first", "last"]: + method = getattr(actual, how) + result = method() + assert_equal(expected, result) + for method in [np.mean]: + result = actual.reduce(method) + assert_equal(expected, result) + + def test_resample_min_count(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], 
np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + # inject nan + ds["foo"] = xr.where(ds["foo"] > 2.0, np.nan, ds["foo"]) + + actual = ds.resample(time="1D").sum(min_count=1) + expected = xr.concat( + [ + ds.isel(time=slice(i * 4, (i + 1) * 4)).sum("time", min_count=1) + for i in range(3) + ], + dim=actual["time"], + ) + assert_equal(expected, actual) + + def test_resample_by_mean_with_keep_attrs(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").mean(keep_attrs=True) + actual = resampled_ds["bar"].attrs + expected = ds["bar"].attrs + assert expected == actual + + actual = resampled_ds.attrs + expected = ds.attrs + assert expected == actual + + with pytest.warns( + UserWarning, match="Passing ``keep_attrs`` to ``resample`` has no effect." + ): + ds.resample(time="1D", keep_attrs=True) + + def test_resample_loffset(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + # Our use of `loffset` may change if we align our API with pandas' changes. + # ref https://github.com/pydata/xarray/pull/4537 + actual = ds.resample(time="24H", loffset="-12H").mean().bar + expected_ = ds.bar.to_series().resample("24H").mean() + expected_.index += to_offset("-12H") + expected = DataArray.from_series(expected_) + assert_allclose(actual, expected) + + def test_resample_by_mean_discarding_attrs(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").mean(keep_attrs=False) + + assert resampled_ds["bar"].attrs == {} + assert resampled_ds.attrs == {} + + def test_resample_by_last_discarding_attrs(self): + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + ds.attrs["dsmeta"] = "dsdata" + + resampled_ds = ds.resample(time="1D").last(keep_attrs=False) + + assert resampled_ds["bar"].attrs == {} + assert resampled_ds.attrs == {} + + @requires_scipy + def test_resample_drop_nondim_coords(self): + xs = np.arange(6) + ys = np.arange(3) + times = pd.date_range("2000-01-01", freq="6H", periods=5) + data = np.tile(np.arange(5), (6, 3, 1)) + xx, yy = np.meshgrid(xs * 5, ys * 2.5) + tt = np.arange(len(times), dtype=int) + array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) + xcoord = DataArray(xx.T, {"x": xs, "y": ys}, ("x", "y")) + ycoord = DataArray(yy.T, {"x": xs, "y": ys}, ("x", "y")) + tcoord = DataArray(tt, {"time": times}, ("time",)) + ds = Dataset({"data": array, "xc": xcoord, "yc": ycoord, "tc": tcoord}) + ds = ds.set_coords(["xc", "yc", "tc"]) + + # Re-sample + actual = ds.resample(time="12H").mean("time") + assert "tc" not in actual.coords + + # Up-sample - filling + actual = ds.resample(time="1H").ffill() + assert "tc" not in actual.coords + + # Up-sample - 
interpolation + actual = ds.resample(time="1H").interpolate("linear") + assert "tc" not in actual.coords + + def test_resample_old_api(self): + + times = pd.date_range("2000-01-01", freq="6H", periods=10) + ds = Dataset( + { + "foo": (["time", "x", "y"], np.random.randn(10, 5, 3)), + "bar": ("time", np.random.randn(10), {"meta": "data"}), + "time": times, + } + ) + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", "time") + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", dim="time", how="mean") + + with pytest.raises(TypeError, match=r"resample\(\) no longer supports"): + ds.resample("1D", dim="time") + + def test_resample_ds_da_are_the_same(self): + time = pd.date_range("2000-01-01", freq="6H", periods=365 * 4) + ds = xr.Dataset( + { + "foo": (("time", "x"), np.random.randn(365 * 4, 5)), + "time": time, + "x": np.arange(5), + } + ) + assert_identical( + ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() + ) + + def test_ds_resample_apply_func_args(self): + def func(arg1, arg2, arg3=0.0): + return arg1.mean("time") + arg2 + arg3 + + times = pd.date_range("2000", freq="D", periods=3) + ds = xr.Dataset({"foo": ("time", [1.0, 1.0, 1.0]), "time": times}) + expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) + actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) + assert_identical(expected, actual) + + # TODO: move other groupby tests from test_dataset and test_dataarray over here From e26aec9500e04f3b926b248988b976dbfcb9c632 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Fri, 20 Aug 2021 00:27:39 +0200 Subject: [PATCH 10/20] Move docstring for xr.set_options to numpy style (#5702) * Move docstring to numpy style * Update options.py * Update options.py --- xarray/core/options.py | 120 +++++++++++++++++++++++------------------ 1 file changed, 69 insertions(+), 51 deletions(-) diff --git a/xarray/core/options.py b/xarray/core/options.py index 524362ce924..e22d8ed99d8 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -151,57 +151,75 @@ def _get_keep_attrs(default): class set_options: - """Set options for xarray in a controlled context. - - Currently supported options: - - - ``display_width``: maximum display width for ``repr`` on xarray objects. - Default: ``80``. - - ``display_max_rows``: maximum display rows. Default: ``12``. - - ``arithmetic_join``: DataArray/Dataset alignment in binary operations. - Default: ``'inner'``. - - ``file_cache_maxsize``: maximum number of open files to hold in xarray's - global least-recently-usage cached. This should be smaller than your - system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. - Default: 128. - - ``warn_for_unclosed_files``: whether or not to issue a warning when - unclosed files are deallocated (default False). This is mostly useful - for debugging. - - ``cmap_sequential``: colormap to use for nondivergent data plots. - Default: ``viridis``. If string, must be matplotlib built-in colormap. - Can also be a Colormap object (e.g. mpl.cm.magma) - - ``cmap_divergent``: colormap to use for divergent data plots. - Default: ``RdBu_r``. If string, must be matplotlib built-in colormap. - Can also be a Colormap object (e.g. mpl.cm.magma) - - ``keep_attrs``: rule for whether to keep attributes on xarray - Datasets/dataarrays after operations. 
Either ``True`` to always keep
-      attrs, ``False`` to always discard them, or ``'default'`` to use original
-      logic that attrs should only be kept in unambiguous circumstances.
-      Default: ``'default'``.
-    - ``use_bottleneck``: allow using bottleneck. Either ``True`` to accelerate
-      operations using bottleneck if it is installed or ``False`` to never use it.
-      Default: ``True``
-    - ``display_style``: display style to use in jupyter for xarray objects.
-      Default: ``'html'``. Other options are ``'text'``.
-    - ``display_expand_attrs``: whether to expand the attributes section for
-      display of ``DataArray`` or ``Dataset`` objects. Can be ``True`` to always
-      expand, ``False`` to always collapse, or ``default`` to expand unless over
-      a pre-defined limit. Default: ``default``.
-    - ``display_expand_coords``: whether to expand the coordinates section for
-      display of ``DataArray`` or ``Dataset`` objects. Can be ``True`` to always
-      expand, ``False`` to always collapse, or ``default`` to expand unless over
-      a pre-defined limit. Default: ``default``.
-    - ``display_expand_data``: whether to expand the data section for display
-      of ``DataArray`` objects. Can be ``True`` to always expand, ``False`` to
-      always collapse, or ``default`` to expand unless over a pre-defined limit.
-      Default: ``default``.
-    - ``display_expand_data_vars``: whether to expand the data variables section
-      for display of ``Dataset`` objects. Can be ``True`` to always
-      expand, ``False`` to always collapse, or ``default`` to expand unless over
-      a pre-defined limit. Default: ``default``.
-
-
-    You can use ``set_options`` either as a context manager:
+    """
+    Set options for xarray in a controlled context.
+
+    Parameters
+    ----------
+    display_width : int, default: 80
+        Maximum display width for ``repr`` on xarray objects.
+    display_max_rows : int, default: 12
+        Maximum display rows.
+    arithmetic_join : {"inner", "outer", "left", "right", "exact"}
+        DataArray/Dataset alignment in binary operations.
+    file_cache_maxsize : int, default: 128
+        Maximum number of open files to hold in xarray's
+        global least-recently-used cache. This should be smaller than
+        your system's per-process file descriptor limit, e.g.,
+        ``ulimit -n`` on Linux.
+    warn_for_unclosed_files : bool, default: False
+        Whether or not to issue a warning when unclosed files are
+        deallocated. This is mostly useful for debugging.
+    cmap_sequential : str or matplotlib.colors.Colormap, default: "viridis"
+        Colormap to use for nondivergent data plots. If string, must be
+        matplotlib built-in colormap. Can also be a Colormap object
+        (e.g. mpl.cm.magma)
+    cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r"
+        Colormap to use for divergent data plots. If string, must be
+        matplotlib built-in colormap. Can also be a Colormap object
+        (e.g. mpl.cm.magma)
+    keep_attrs : {"default", True, False}
+        Whether to keep attributes on xarray Datasets/dataarrays after
+        operations. Can be
+
+        * ``True`` : to always keep attrs
+        * ``False`` : to always discard attrs
+        * ``default`` : to use original logic that attrs should only
+          be kept in unambiguous circumstances
+    display_style : {"text", "html"}
+        Display style to use in jupyter for xarray objects.
+    display_expand_attrs : {"default", True, False}:
+        Whether to expand the attributes section for display of
+        ``DataArray`` or ``Dataset`` objects.
Can be + + * ``True`` : to always expand attrs + * ``False`` : to always collapse attrs + * ``default`` : to expand unless over a pre-defined limit + display_expand_coords : {"default", True, False}: + Whether to expand the coordinates section for display of + ``DataArray`` or ``Dataset`` objects. Can be + + * ``True`` : to always expand coordinates + * ``False`` : to always collapse coordinates + * ``default`` : to expand unless over a pre-defined limit + display_expand_data : {"default", True, False}: + Whether to expand the data section for display of ``DataArray`` + objects. Can be + + * ``True`` : to always expand data + * ``False`` : to always collapse data + * ``default`` : to expand unless over a pre-defined limit + display_expand_data_vars : {"default", True, False}: + Whether to expand the data variables section for display of + ``Dataset`` objects. Can be + + * ``True`` : to always expand data variables + * ``False`` : to always collapse data variables + * ``default`` : to expand unless over a pre-defined limit + + Examples + -------- + It is possible to use ``set_options`` either as a context manager: >>> ds = xr.Dataset({"x": np.arange(1000)}) >>> with xr.set_options(display_width=40): From 1434b8d04c5d6d07472dc9997a589353e7da4b06 Mon Sep 17 00:00:00 2001 From: keewis Date: Sat, 21 Aug 2021 13:24:37 +0200 Subject: [PATCH 11/20] extend show_versions (#5724) * print sparse in print_versions * add fsspec to print_versions * add cupy [skip-ci] --- xarray/util/print_versions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index cd5d425efe2..561126ea05f 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -118,7 +118,10 @@ def show_versions(file=sys.stdout): ("cartopy", lambda mod: mod.__version__), ("seaborn", lambda mod: mod.__version__), ("numbagg", lambda mod: mod.__version__), + ("fsspec", lambda mod: mod.__version__), + ("cupy", lambda mod: mod.__version__), ("pint", lambda mod: mod.__version__), + ("sparse", lambda mod: mod.__version__), # xarray setup/test ("setuptools", lambda mod: mod.__version__), ("pip", lambda mod: mod.__version__), From 48a9dbe7d8dc2361bc985dd9fb1193a26135b310 Mon Sep 17 00:00:00 2001 From: Akio Taniguchi Date: Sun, 22 Aug 2021 00:28:31 +0900 Subject: [PATCH 12/20] Add xarray-dataclasses to ecosystem in docs (#5725) * Remove xarray-custom from ecosystem in docs * Add xarray-dataclasses to ecosystem in docs --- doc/ecosystem.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index 01f5c29b9f5..9e81679f693 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -68,7 +68,7 @@ Extend xarray capabilities - `hypothesis-gufunc `_: Extension to hypothesis. Makes it easy to write unit tests with xarray objects as input. - `nxarray `_: NeXus input/output capability for xarray. - `xarray-compare `_: xarray extension for data comparison. -- `xarray-custom `_: Data classes for custom xarray creation. +- `xarray-dataclasses `_: xarray extension for typed DataArray and Dataset creation. - `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). - `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. - `xrft `_: Fourier transforms for xarray data. 
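As a quick check of the extended ``show_versions`` report from the patch above (a minimal sketch; which of the optional libraries show up depends on what is importable in the running environment):

    import xarray as xr

    # Print the versions of xarray and its dependencies; with the extension
    # above, the report also covers fsspec, cupy and sparse when installed.
    xr.show_versions()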
From 28bdcf032f34e358ae4e145267fbcb3c5eadae91 Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Sat, 21 Aug 2021 15:34:08 -0700
Subject: [PATCH 13/20] Xfail failing test on main (#5729)

---
 xarray/tests/test_sparse.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
index 7401b15da42..044600a7f3c 100644
--- a/xarray/tests/test_sparse.py
+++ b/xarray/tests/test_sparse.py
@@ -854,6 +854,7 @@ def test_sparse_coords(self):
         )
 
 
+@pytest.mark.xfail(reason="https://github.com/pydata/xarray/issues/5654")
 @requires_dask
 def test_chunk():
     s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))

From a78c1e0115d38cb4461fd1aba93334d440cff49c Mon Sep 17 00:00:00 2001
From: Stefan Bender 
Date: Sun, 22 Aug 2021 00:51:02 +0200
Subject: [PATCH 14/20] dataset `__repr__` updates (#5580)

* Delegate `max_rows` from dataset `__repr__`

As discussed in https://github.com/pydata/xarray/issues/5545
the default setting of `display_max_rows` is sometimes less useful
for inspecting datasets for completeness, and changing the option is
not backwards compatible. In addition, a concise and short dataset
output seems to be preferred by most people.
The compromise is to keep the dataset's `__repr__` short and tunable
via `xr.set_options(display_max_rows=...)`, and at the same time to
enable the output of all items by explicitly requesting `ds.coords`,
`ds.data_vars`, and `ds.attrs`. These explicit `__repr__`s also
restore backwards compatibility in these cases.
Slightly changes the internal implementation of `_mapping_repr()`:
Setting (leaving) `max_rows` to `None` means "no limits".

* tests: Update dataset `__repr__` tests [1/2]

Updates the dataset `__repr__` test to ensure that the dataset output
honours the `display_max_rows` setting, not the `data_vars` output.
Discussed in https://github.com/pydata/xarray/issues/5545

* tests: Extend dataset `__repr__` tests [2/2]

Extends the dataset `__repr__` test to ensure that the output of
`ds.coords`, `ds.data_vars`, and `ds.attrs` is of full length as
desired. Includes more dimensions and coordinates to cover more cases.
Discussed in https://github.com/pydata/xarray/issues/5545

* doc: Add what's new entry for `__repr__` changes

Sorted as a "breaking change" for 0.18.3 for now.

* Revert "doc: Add what's new entry for `__repr__` changes"

This reverts commit 3dd645b3141cdb97567cc245aa836b22fab2b3da.

* doc: Add what's new entry for `__repr__` changes

Sorted as a "breaking change" for 0.19.1 for now.

* doc: Remove `attrs` from `__repr__` changes

Address comment from @keewis: `.attrs` is a standard python dict,
so there's no custom repr.

* tests: Remove `ds.attrs` formatting test

According to @keewis: I don't think we need to test this because
`attrs_repr` will only ever be called by `dataset_repr` / `array_repr`:
on its own, the standard python `dict`'s `repr` will be used.

* tests: Fix no. of coordinates in formatting_repr

The number of coordinates changed to be the same as the number of
variables, which only incidentally was fixed at 40.
Updates the to-be-tested format string to use the same number of
variables instead of the hard-coded one, which might be subject to
change.
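A minimal sketch of the resulting behavior (the variable names and the
row limit below are arbitrary examples):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({f"var{i}": ("x", np.arange(2)) for i in range(20)})

    with xr.set_options(display_max_rows=12):
        print(ds)            # dataset repr is truncated to 12 data_vars rows
        print(ds.data_vars)  # the explicit mapping repr lists all 20 variables

This mirrors the split described above: `dataset_repr` resolves
`display_max_rows` itself, while the standalone mapping reprs default to
`max_rows=None` and therefore show everything.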
--- doc/whats-new.rst | 5 +++++ xarray/core/formatting.py | 14 +++++++------- xarray/tests/test_formatting.py | 29 +++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4f79a37eb4b..aa6ac4016a7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,11 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- The ``__repr__`` of a :py:class:`xarray.Dataset`'s ``coords`` and ``data_vars`` + ignore ``xarray.set_option(display_max_rows=...)`` and show the full output + when called directly as, e.g., ``ds.data_vars`` or ``print(ds.data_vars)`` + (:issue:`5545`, :pull:`5580`). + By `Stefan Bender `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 7f292605e63..70d1a61f56c 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -377,14 +377,12 @@ def _mapping_repr( ): if col_width is None: col_width = _calculate_col_width(mapping) - if max_rows is None: - max_rows = OPTIONS["display_max_rows"] summary = [f"{title}:"] if mapping: len_mapping = len(mapping) if not _get_boolean_with_default(expand_option_name, default=True): summary = [f"{summary[0]} ({len_mapping})"] - elif len_mapping > max_rows: + elif max_rows is not None and len_mapping > max_rows: summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = max_rows // 2 + max_rows % 2 keys = list(mapping.keys()) @@ -418,7 +416,7 @@ def _mapping_repr( ) -def coords_repr(coords, col_width=None): +def coords_repr(coords, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(_get_col_items(coords)) return _mapping_repr( @@ -427,6 +425,7 @@ def coords_repr(coords, col_width=None): summarizer=summarize_coord, expand_option_name="display_expand_coords", col_width=col_width, + max_rows=max_rows, ) @@ -544,21 +543,22 @@ def dataset_repr(ds): summary = ["<xarray.{}>".format(type(ds).__name__)] col_width = _calculate_col_width(_get_col_items(ds.variables)) + max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) summary.append("{}({})".format(dims_start, dim_summary(ds))) if ds.coords: - summary.append(coords_repr(ds.coords, col_width=col_width)) + summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows)) unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords) if unindexed_dims_str: summary.append(unindexed_dims_str) - summary.append(data_vars_repr(ds.data_vars, col_width=col_width)) + summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows)) if ds.attrs: - summary.append(attrs_repr(ds.attrs)) + summary.append(attrs_repr(ds.attrs, max_rows=max_rows)) return "\n".join(summary) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index b9ba57f99dc..d5e7d2ee232 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -509,15 +509,16 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr): long_name = "long_name" a = np.core.defchararray.add(long_name, np.arange(0, n_vars).astype(str)) b = np.core.defchararray.add("attr_", np.arange(0, n_attr).astype(str)) + c = np.core.defchararray.add("coord", np.arange(0, n_vars).astype(str)) attrs = {k: 2 for k in b} - coords = dict(time=np.array([0, 1])) + coords = {_c: np.array([0, 1]) for _c in c} data_vars = dict() - for v in a: + for (v, _c) in zip(a, coords.items()): data_vars[v] = xr.DataArray( name=v, data=np.array([3, 4]), - dims=["time"], - coords=coords, + dims=[_c[0]], + 
coords=dict([_c]), ) ds = xr.Dataset(data_vars) ds.attrs = attrs @@ -525,25 +526,37 @@ with xr.set_options(display_max_rows=display_max_rows): # Parse the data_vars print and show only data_vars rows: - summary = formatting.data_vars_repr(ds.data_vars).split("\n") + summary = formatting.dataset_repr(ds).split("\n") summary = [v for v in summary if long_name in v] - # The length should be less than or equal to display_max_rows: len_summary = len(summary) data_vars_print_size = min(display_max_rows, len_summary) assert len_summary == data_vars_print_size + summary = formatting.data_vars_repr(ds.data_vars).split("\n") + summary = [v for v in summary if long_name in v] + # The length should be equal to the number of data variables + len_summary = len(summary) + assert len_summary == n_vars + + summary = formatting.coords_repr(ds.coords).split("\n") + summary = [v for v in summary if "coord" in v] + # The length should be equal to the number of coordinates + len_summary = len(summary) + assert len_summary == n_vars + with xr.set_options( display_expand_coords=False, display_expand_data_vars=False, display_expand_attrs=False, ): actual = formatting.dataset_repr(ds) + coord_s = ", ".join([f"{c}: {len(v)}" for c, v in coords.items()]) expected = dedent( f"""\ <xarray.Dataset> - Dimensions: (time: 2) - Coordinates: (1) + Dimensions: ({coord_s}) + Coordinates: ({n_vars}) Data variables: ({n_vars}) Attributes: ({n_attr})""" ) From befd1b98bd84047d62307419a30bcda7a0727926 Mon Sep 17 00:00:00 2001 From: Ray Bell Date: Sat, 21 Aug 2021 18:52:18 -0400 Subject: [PATCH 15/20] add storage_options arg to to_zarr (#5615) * add storage_options arg to to_zarr * add try import * add what's new * merge main whats-new * undo whats new * move import fsspec lower * fsspec in to_zarr * add a test. Co-authored-by: Zachary Blackwood Co-authored-by: Nathan Lis * add requires_zarr_2_5_0 * add what's new * add storage options arg to end Co-authored-by: Ray Bell Co-authored-by: Zachary Blackwood Co-authored-by: Nathan Lis --- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 21 +++++++++++++++++++-- xarray/backends/zarr.py | 3 +++ xarray/core/dataset.py | 11 ++++++++--- xarray/tests/__init__.py | 1 + xarray/tests/test_backends.py | 12 ++++++++++++ 6 files changed, 46 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aa6ac4016a7..942ee51fff9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -28,6 +28,9 @@ New Features By `Justus Magin `_. - Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`). By `Pushkar Kopparla `_. +- Added ``storage_options`` argument to :py:meth:`to_zarr` (:issue:`5601`). + By `Ray Bell `_, `Zachary Blackwood `_ and + `Nathan Lis `_.
Breaking changes diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9b4fa8fce5a..2c9b25f860f 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1319,6 +1319,7 @@ def to_zarr( append_dim: Hashable = None, region: Mapping[str, slice] = None, safe_chunks: bool = True, + storage_options: Dict[str, str] = None, ): """This function creates an appropriate datastore for writing a dataset to a zarr store @@ -1330,6 +1331,22 @@ def to_zarr( store = _normalize_path(store) chunk_store = _normalize_path(chunk_store) + if storage_options is None: + mapper = store + chunk_mapper = chunk_store + else: + from fsspec import get_mapper + + if not isinstance(store, str): + raise ValueError( + f"store must be a string to use storage_options. Got {type(store)}" + ) + mapper = get_mapper(store, **storage_options) + if chunk_store is not None: + chunk_mapper = get_mapper(chunk_store, **storage_options) + else: + chunk_mapper = chunk_store + if encoding is None: encoding = {} @@ -1372,13 +1389,13 @@ def to_zarr( already_consolidated = False consolidate_on_close = consolidated or consolidated is None zstore = backends.ZarrStore.open_group( - store=store, + store=mapper, mode=mode, synchronizer=synchronizer, group=group, consolidated=already_consolidated, consolidate_on_close=consolidate_on_close, - chunk_store=chunk_store, + chunk_store=chunk_mapper, append_dim=append_dim, write_region=region, safe_chunks=safe_chunks, diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index aec12d2b154..12499103fb9 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -713,6 +713,9 @@ def open_zarr( falling back to read non-consolidated metadata if that fails. chunk_store : MutableMapping, optional A separate Zarr store only for chunk data. + storage_options : dict, optional + Any additional parameters for the storage backend (ignored for local + paths). decode_timedelta : bool, optional If True, decode variables and coordinates with time units in {'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds'} diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 90c395ed39b..a5eaa82cfdd 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1922,6 +1922,7 @@ def to_zarr( append_dim: Hashable = None, region: Mapping[str, slice] = None, safe_chunks: bool = True, + storage_options: Dict[str, str] = None, ) -> "ZarrStore": """Write dataset contents to a zarr group. @@ -1941,10 +1942,10 @@ def to_zarr( Parameters ---------- store : MutableMapping, str or Path, optional - Store or path to directory in file system. + Store or path to directory in local or remote file system. chunk_store : MutableMapping, str or Path, optional - Store or path to directory in file system only for Zarr array chunks. - Requires zarr-python v2.4.0 or later. + Store or path to directory in local or remote file system only for Zarr + array chunks. Requires zarr-python v2.4.0 or later. mode : {"w", "w-", "a", "r+", None}, optional Persistence mode: "w" means create (overwrite if exists); "w-" means create (fail if exists); @@ -1999,6 +2000,9 @@ def to_zarr( if Zarr arrays are written in parallel. This option may be useful in combination with ``compute=False`` to initialize a Zarr from an existing Dataset with arbitrary chunk structure. + storage_options : dict, optional + Any additional parameters for the storage backend (ignored for local + paths).
References ---------- @@ -2031,6 +2035,7 @@ def to_zarr( self, store=store, chunk_store=chunk_store, + storage_options=storage_options, mode=mode, synchronizer=synchronizer, group=group, diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index d757fb451cc..f610941914b 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -77,6 +77,7 @@ def LooseVersion(vstring): has_nc_time_axis, requires_nc_time_axis = _importorskip("nc_time_axis") has_rasterio, requires_rasterio = _importorskip("rasterio") has_zarr, requires_zarr = _importorskip("zarr") +has_zarr_2_5_0, requires_zarr_2_5_0 = _importorskip("zarr", minversion="2.5.0") has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_cfgrib, requires_cfgrib = _importorskip("cfgrib") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3bbc2c93b31..3ca20cade56 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -71,6 +71,7 @@ requires_scipy, requires_scipy_or_netCDF4, requires_zarr, + requires_zarr_2_5_0, ) from .test_coding_times import ( _ALL_CALENDARS, @@ -2388,6 +2389,17 @@ def create_zarr_target(self): yield tmp +@requires_fsspec +@requires_zarr_2_5_0 +def test_zarr_storage_options(): + pytest.importorskip("aiobotocore") + ds = create_test_data() + store_target = "memory://test.zarr" + ds.to_zarr(store_target, storage_options={"test": "zarr_write"}) + ds_a = xr.open_zarr(store_target, storage_options={"test": "zarr_read"}) + assert_identical(ds, ds_a) + + @requires_scipy class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): engine = "scipy" From 6b59d9a5058adb8de1661f49bc273aa78d6de7de Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 21 Aug 2021 16:38:14 -0700 Subject: [PATCH 16/20] Consolidate TypeVars in a single place (#5569) * Consolidate type bounds in a single place * More consolidation * Update xarray/core/types.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Update xarray/core/types.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Rename T_DSorDA to T_Xarray * Update xarray/core/weighted.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * Update xarray/core/rolling_exp.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * . 
Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/_typed_ops.pyi | 21 ++++++++++----------- xarray/core/common.py | 8 ++++---- xarray/core/computation.py | 15 ++++++--------- xarray/core/dataarray.py | 11 ++++++----- xarray/core/dataset.py | 8 +++----- xarray/core/parallel.py | 15 +++++++++------ xarray/core/rolling_exp.py | 21 +++++++++------------ xarray/core/types.py | 31 +++++++++++++++++++++++++++++++ xarray/core/variable.py | 35 ++++++++++++++--------------------- xarray/core/weighted.py | 27 +++++++++++++-------------- xarray/util/generate_ops.py | 21 +++++++++++---------- 11 files changed, 116 insertions(+), 97 deletions(-) create mode 100644 xarray/core/types.py diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index 4a6c2dc7b4e..0e0573a0afe 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -9,6 +9,16 @@ from .dataarray import DataArray from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: @@ -16,17 +26,6 @@ try: except ImportError: DaskArray = np.ndarray -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy] - class DatasetOpsMixin: __slots__ = () def _binary_op(self, other, f, reflexive=...): ... diff --git a/xarray/core/common.py b/xarray/core/common.py index d3001532aa0..74763829856 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings from contextlib import suppress from html import escape @@ -36,10 +38,10 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset + from .types import T_DataWithCoords, T_Xarray from .variable import Variable from .weighted import Weighted -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") C = TypeVar("C") T = TypeVar("T") @@ -795,9 +797,7 @@ def groupby_bins( }, ) - def weighted( - self: T_DataWithCoords, weights: "DataArray" - ) -> "Weighted[T_DataWithCoords]": + def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]: """ Weighted operations. 
diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 5bfb14793bb..9278577cbd6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -21,7 +21,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, ) @@ -36,11 +35,9 @@ from .variable import Variable if TYPE_CHECKING: - from .coordinates import Coordinates # noqa - from .dataarray import DataArray + from .coordinates import Coordinates from .dataset import Dataset - - T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) + from .types import T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any: return name -def _get_coords_list(args) -> List["Coordinates"]: +def _get_coords_list(args) -> List[Coordinates]: coords_list = [] for arg in args: try: @@ -400,8 +397,8 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable] -) -> "Dataset": + variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] +) -> Dataset: """Create a dataset as quickly as possible. Beware: the `variables` dict is modified INPLACE. @@ -1729,7 +1726,7 @@ def _calc_idxminmax( return res -def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]: +def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with unified chunk size along all chunked dimensions. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e86bae08a3f..e0c1a1c9d17 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import warnings from typing import ( @@ -12,7 +14,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, cast, ) @@ -70,8 +71,6 @@ assert_unique_multiindex_level_names, ) -T_DataArray = TypeVar("T_DataArray", bound="DataArray") -T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset) if TYPE_CHECKING: try: from dask.delayed import Delayed @@ -86,6 +85,8 @@ except ImportError: iris_Cube = None + from .types import T_DataArray, T_Xarray + def _infer_coords_and_dims( shape, coords, dims @@ -3698,11 +3699,11 @@ def unify_chunks(self) -> "DataArray": def map_blocks( self, - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> T_DSorDA: + ) -> T_Xarray: """ Apply a function to each block of this DataArray. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a5eaa82cfdd..98374fa2ba3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -25,7 +25,6 @@ Sequence, Set, Tuple, - TypeVar, Union, cast, overload, @@ -109,8 +108,7 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - - T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset") + from .types import T_Xarray try: from dask.delayed import Delayed @@ -6630,11 +6628,11 @@ def unify_chunks(self) -> "Dataset": def map_blocks( self, - func: "Callable[..., T_DSorDA]", + func: "Callable[..., T_Xarray]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union["DataArray", "Dataset"] = None, - ) -> "T_DSorDA": + ) -> "T_Xarray": """ Apply a function to each block of this Dataset. 
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 20ec3608ebb..4917714a9c2 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import collections import itertools import operator from typing import ( + TYPE_CHECKING, Any, Callable, DefaultDict, @@ -12,7 +15,6 @@ Mapping, Sequence, Tuple, - TypeVar, Union, ) @@ -32,7 +34,8 @@ pass -T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +if TYPE_CHECKING: + from .types import T_Xarray def unzip(iterable): @@ -122,8 +125,8 @@ def make_meta(obj): def infer_template( - func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs -) -> T_DSorDA: + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs +) -> T_Xarray: """Infer return object by running the function on meta objects.""" meta_args = [make_meta(arg) for arg in (obj,) + args] @@ -162,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping def map_blocks( - func: Callable[..., T_DSorDA], + func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, template: Union[DataArray, Dataset] = None, -) -> T_DSorDA: +) -> T_Xarray: """Apply a function to each block of a DataArray or Dataset. .. warning:: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index e0fe57a9fb0..31718267ee0 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,17 +1,14 @@ +from __future__ import annotations + from distutils.version import LooseVersion -from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union +from typing import Generic, Hashable, Mapping, Union import numpy as np from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array - -if TYPE_CHECKING: - from .dataarray import DataArray # noqa: F401 - from .dataset import Dataset # noqa: F401 - -T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") +from .types import T_Xarray def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp(Generic[T_DSorDA]): +class RollingExp(Generic[T_Xarray]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]): def __init__( self, - obj: T_DSorDA, + obj: T_Xarray, windows: Mapping[Hashable, Union[int, float]], window_type: str = "span", ): - self.obj: T_DSorDA = obj + self.obj: T_Xarray = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self, keep_attrs: bool = None) -> T_DSorDA: + def mean(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving average. @@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA: move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs ) - def sum(self, keep_attrs: bool = None) -> T_DSorDA: + def sum(self, keep_attrs: bool = None) -> T_Xarray: """ Exponentially weighted moving sum. 
diff --git a/xarray/core/types.py b/xarray/core/types.py new file mode 100644 index 00000000000..26b1b381c30 --- /dev/null +++ b/xarray/core/types.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar, Union + +import numpy as np + +if TYPE_CHECKING: + from .common import DataWithCoords + from .dataarray import DataArray + from .dataset import Dataset + from .groupby import DataArrayGroupBy, GroupBy + from .npcompat import ArrayLike + from .variable import Variable + + try: + from dask.array import Array as DaskArray + except ImportError: + DaskArray = np.ndarray + +T_Dataset = TypeVar("T_Dataset", bound="Dataset") +T_DataArray = TypeVar("T_DataArray", bound="DataArray") +T_Variable = TypeVar("T_Variable", bound="Variable") +# Maybe we rename this to T_Data or something less Fortran-y? +T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset") +T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") + +ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] +DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] +VarCompatible = Union["Variable", "ScalarOrArray"] +GroupByIncompatible = Union["Variable", "GroupBy"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2798a4ab956..f96cbe63d07 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import itertools import numbers @@ -5,6 +7,7 @@ from collections import defaultdict from datetime import timedelta from typing import ( + TYPE_CHECKING, Any, Dict, Hashable, @@ -13,7 +16,6 @@ Optional, Sequence, Tuple, - TypeVar, Union, ) @@ -66,17 +68,8 @@ # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) -VariableType = TypeVar("VariableType", bound="Variable") -"""Type annotation to be used when methods of Variable return self or a copy of self. -When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the -output as an instance of the subclass. - -Usage:: - - class Variable: - def f(self: VariableType, ...) -> VariableType: - ... -""" +if TYPE_CHECKING: + from .types import T_Variable class MissingDimensionsError(ValueError): @@ -362,7 +355,7 @@ def data(self, data): self._data = data def astype( - self: VariableType, + self: T_Variable, dtype, *, order=None, @@ -370,7 +363,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> VariableType: + ) -> T_Variable: """ Copy of the Variable object, with data cast to a specified type. @@ -775,7 +768,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: VariableType, key) -> VariableType: + def __getitem__(self: T_Variable, key) -> T_Variable: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. 
@@ -794,7 +787,7 @@ def __getitem__(self: VariableType, key) -> VariableType: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: + def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -974,12 +967,12 @@ def copy(self, deep=True, data=None): return self._replace(data=data) def _replace( - self: VariableType, + self: T_Variable, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> VariableType: + ) -> T_Variable: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -1101,7 +1094,7 @@ def to_numpy(self) -> np.ndarray: return data - def as_numpy(self: VariableType) -> VariableType: + def as_numpy(self: T_Variable) -> T_Variable: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) @@ -1136,11 +1129,11 @@ def _to_dense(self): return self.copy(deep=False) def isel( - self: VariableType, + self: T_Variable, indexers: Mapping[Any, Any] = None, missing_dims: str = "raise", **indexers_kwargs: Any, - ) -> VariableType: + ) -> T_Variable: """Return a new array indexed along the specified dimension(s). Parameters diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index e8838b07157..c31b24f53b5 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,15 +1,9 @@ -from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, Union from . import duck_array_ops from .computation import dot from .pycompat import is_duck_dask_array - -if TYPE_CHECKING: - from .common import DataWithCoords # noqa: F401 - from .dataarray import DataArray, Dataset - -T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") - +from .types import T_Xarray _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """ Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s). @@ -59,7 +53,12 @@ """ -class Weighted(Generic[T_DataWithCoords]): +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + + +class Weighted(Generic[T_Xarray]): """An object that implements weighted operations. 
You should create a Weighted object by using the ``DataArray.weighted`` or @@ -73,7 +72,7 @@ class Weighted(Generic[T_DataWithCoords]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_DataWithCoords, weights: "DataArray"): + def __init__(self, obj: T_Xarray, weights: "DataArray"): """ Create a Weighted object @@ -116,7 +115,7 @@ def _weight_check(w): else: _weight_check(weights.data) - self.obj: T_DataWithCoords = obj + self.obj: T_Xarray = obj self.weights: "DataArray" = weights def _check_dim(self, dim: Optional[Union[Hashable, Iterable[Hashable]]]): @@ -210,7 +209,7 @@ def sum_of_weights( self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs @@ -221,7 +220,7 @@ def sum( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs @@ -232,7 +231,7 @@ def mean( dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, skipna: Optional[bool] = None, keep_attrs: Optional[bool] = None, - ) -> T_DataWithCoords: + ) -> T_Xarray: return self._implementation( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index b6b7f8cbac7..1ccede945bf 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -198,23 +198,24 @@ def inplace(): from .dataset import Dataset from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy from .npcompat import ArrayLike +from .types import ( + DaCompatible, + DsCompatible, + GroupByIncompatible, + ScalarOrArray, + T_DataArray, + T_Dataset, + T_Variable, + VarCompatible, +) from .variable import Variable try: from dask.array import Array as DaskArray except ImportError: DaskArray = np.ndarray +''' -# DatasetOpsMixin etc. are parent classes of Dataset etc. -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray] -DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray] -DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray] -VarCompatible = Union[Variable, ScalarOrArray] -GroupByIncompatible = Union[Variable, GroupBy]''' CLASS_PREAMBLE = """{newline} class {cls_name}: From 4f1e2d37b662079e830c9672400fabc19b44a376 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 21 Aug 2021 20:32:22 -0700 Subject: [PATCH 17/20] Type annotate tests (#5728) * Type annotate lots of tests * fixes for newer numpy version * . * . 
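The pattern applied throughout the following diff is mechanical: give each test an explicit ``-> None`` return annotation so that mypy checks the function body. A representative sketch (this specific test is illustrative, not taken from the patch):

    import numpy as np
    import xarray as xr

    def test_mean_drops_reduced_dim() -> None:
        # Without a return annotation mypy treats the function as untyped and
        # skips its body; annotating it lets CI catch typing errors in tests.
        da = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y"))
        assert da.mean("y").dims == ("x",)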
--- doc/whats-new.rst | 3 + properties/test_encode_decode.py | 4 +- xarray/core/common.py | 8 +- xarray/core/computation.py | 2 +- xarray/core/coordinates.py | 2 +- xarray/core/dataarray.py | 35 ++-- xarray/core/dataset.py | 32 ++-- xarray/core/indexes.py | 8 +- xarray/core/merge.py | 18 +- xarray/core/rolling_exp.py | 4 +- xarray/core/utils.py | 4 +- xarray/core/variable.py | 16 +- xarray/tests/test_accessor_dt.py | 121 +++++++------ xarray/tests/test_accessor_str.py | 200 +++++++++++---------- xarray/tests/test_backends_api.py | 4 +- xarray/tests/test_backends_common.py | 2 +- xarray/tests/test_backends_file_manager.py | 35 ++-- xarray/tests/test_backends_locks.py | 2 +- xarray/tests/test_backends_lru_cache.py | 29 +-- xarray/tests/test_cftimeindex_resample.py | 6 +- xarray/tests/test_coarsen.py | 18 +- xarray/tests/test_coding.py | 18 +- xarray/tests/test_coding_strings.py | 32 ++-- xarray/tests/test_coding_times.py | 122 +++++++------ xarray/tests/test_computation.py | 118 ++++++------ xarray/tests/test_conventions.py | 52 +++--- xarray/tests/test_cupy.py | 6 +- xarray/tests/test_distributed.py | 14 +- xarray/tests/test_extensions.py | 17 +- xarray/tests/test_formatting.py | 59 +++--- xarray/tests/test_formatting_html.py | 42 +++-- xarray/tests/test_groupby.py | 84 ++++----- xarray/tests/test_indexes.py | 30 ++-- xarray/tests/test_indexing.py | 82 ++++----- xarray/tests/test_nputils.py | 4 +- xarray/tests/test_options.py | 34 ++-- xarray/tests/test_plugins.py | 32 ++-- xarray/tests/test_print_versions.py | 2 +- xarray/tests/test_testing.py | 10 +- xarray/tests/test_tutorial.py | 6 +- 40 files changed, 673 insertions(+), 644 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 942ee51fff9..c0d8dc5a367 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -63,6 +63,9 @@ Internal Changes By `Benoit Bovy `_. - Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`). By `Maximilian Roos `_. +- Annotate many of our tests, and fix some of the resulting typing errors. This will + also mean our typing annotations are tested as part of CI. (:pull:`5728`). + By `Maximilian Roos `_. - Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) By `Jimmy Westling `_. - Use isort's `float_to_top` config. (:pull:`5695`). 
diff --git a/properties/test_encode_decode.py b/properties/test_encode_decode.py index 4b0643cb2fe..3ba037e28b0 100644 --- a/properties/test_encode_decode.py +++ b/properties/test_encode_decode.py @@ -25,7 +25,7 @@ @pytest.mark.slow @given(st.data(), an_array) -def test_CFMask_coder_roundtrip(data, arr): +def test_CFMask_coder_roundtrip(data, arr) -> None: names = data.draw( st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( tuple @@ -39,7 +39,7 @@ def test_CFMask_coder_roundtrip(data, arr): @pytest.mark.slow @given(st.data(), an_array) -def test_CFScaleOffset_coder_roundtrip(data, arr): +def test_CFScaleOffset_coder_roundtrip(data, arr) -> None: names = data.draw( st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( tuple diff --git a/xarray/core/common.py b/xarray/core/common.py index 74763829856..0f2b58d594a 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -411,7 +411,7 @@ def get_index(self, key: Hashable) -> pd.Index: return pd.Index(range(self.sizes[key]), name=key) def _calc_assign_results( - self: C, kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] + self: C, kwargs: Mapping[Any, Union[T, Callable[[C], T]]] ) -> Dict[Hashable, T]: return {k: v(self) if callable(v) else v for k, v in kwargs.items()} @@ -820,7 +820,7 @@ def rolling( self, dim: Mapping[Any, int] = None, min_periods: int = None, - center: Union[bool, Mapping[Hashable, bool]] = False, + center: Union[bool, Mapping[Any, bool]] = False, **window_kwargs: int, ): """ @@ -935,7 +935,7 @@ def coarsen( self, dim: Mapping[Any, int] = None, boundary: str = "exact", - side: Union[str, Mapping[Hashable, str]] = "left", + side: Union[str, Mapping[Any, str]] = "left", coord_func: str = "mean", **window_kwargs: int, ): @@ -1520,7 +1520,7 @@ def __getitem__(self, value): def full_like( other: "Dataset", fill_value, - dtype: Union[DTypeLike, Mapping[Hashable, DTypeLike]] = None, + dtype: Union[DTypeLike, Mapping[Any, DTypeLike]] = None, ) -> "Dataset": ... diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 9278577cbd6..b891f31f9c5 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -367,7 +367,7 @@ def _as_variables_or_variable(arg): def _unpack_dict_tuples( - result_vars: Mapping[Hashable, Tuple[Variable, ...]], num_outputs: int + result_vars: Mapping[Any, Tuple[Variable, ...]], num_outputs: int ) -> Tuple[Dict[Hashable, Variable], ...]: out: Tuple[Dict[Hashable, Variable], ...] 
= tuple({} for _ in range(num_outputs)) for name, values in result_vars.items(): diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 56afdb9774a..9dfd64e9c99 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -31,7 +31,7 @@ _THIS_ARRAY = ReprObject("") -class Coordinates(Mapping[Hashable, "DataArray"]): +class Coordinates(Mapping[Any, "DataArray"]): __slots__ = () def __getitem__(self, key: Hashable) -> "DataArray": diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e0c1a1c9d17..fd3f2d6529e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -366,7 +366,7 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic): def __init__( self, data: Any = dtypes.NA, - coords: Union[Sequence[Tuple], Mapping[Hashable, Any], None] = None, + coords: Union[Sequence[Tuple], Mapping[Any, Any], None] = None, dims: Union[Hashable, Sequence[Hashable], None] = None, name: Hashable = None, attrs: Mapping = None, @@ -788,7 +788,8 @@ def loc(self) -> _LocIndexer: return _LocIndexer(self) @property - def attrs(self) -> Dict[Hashable, Any]: + # Key type needs to be `Any` because of mypy#4167 + def attrs(self) -> Dict[Any, Any]: """Dictionary storing arbitrary metadata with this array.""" return self.variable.attrs @@ -1068,7 +1069,7 @@ def chunk( int, Tuple[int, ...], Tuple[Tuple[int, ...], ...], - Mapping[Hashable, Union[None, int, Tuple[int, ...]]], + Mapping[Any, Union[None, int, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, @@ -1312,7 +1313,7 @@ def sel( def head( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, int], int] = None, **indexers_kwargs: Any, ) -> "DataArray": """Return a new DataArray whose data is given by the the first `n` @@ -1329,7 +1330,7 @@ def head( def tail( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, int], int] = None, **indexers_kwargs: Any, ) -> "DataArray": """Return a new DataArray whose data is given by the the last `n` @@ -1346,7 +1347,7 @@ def tail( def thin( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, int], int] = None, **indexers_kwargs: Any, ) -> "DataArray": """Return a new DataArray whose data is given by each `n` value @@ -1778,7 +1779,7 @@ def interp_like( def rename( self, - new_name_or_name_dict: Union[Hashable, Mapping[Hashable, Hashable]] = None, + new_name_or_name_dict: Union[Hashable, Mapping[Any, Hashable]] = None, **names: Hashable, ) -> "DataArray": """Returns a new DataArray with renamed coordinates or a new name. 
@@ -1874,7 +1875,7 @@ def swap_dims( def expand_dims( self, - dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, + dim: Union[None, Hashable, Sequence[Hashable], Mapping[Any, Any]] = None, axis=None, **dim_kwargs: Any, ) -> "DataArray": @@ -1926,7 +1927,7 @@ def expand_dims( def set_index( self, - indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, + indexes: Mapping[Any, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> "DataArray": @@ -2014,7 +2015,7 @@ def reset_index( def reorder_levels( self, - dim_order: Mapping[Hashable, Sequence[int]] = None, + dim_order: Mapping[Any, Sequence[int]] = None, **dim_order_kwargs: Sequence[int], ) -> "DataArray": """Rearrange index levels using input order. @@ -2049,7 +2050,7 @@ def reorder_levels( def stack( self, - dimensions: Mapping[Hashable, Sequence[Hashable]] = None, + dimensions: Mapping[Any, Sequence[Hashable]] = None, **dimensions_kwargs: Sequence[Hashable], ) -> "DataArray": """ @@ -3868,17 +3869,13 @@ def polyfit( def pad( self, - pad_width: Mapping[Hashable, Union[int, Tuple[int, int]]] = None, + pad_width: Mapping[Any, Union[int, Tuple[int, int]]] = None, mode: str = "constant", - stat_length: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, + stat_length: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, constant_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, - end_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] + int, Tuple[int, int], Mapping[Any, Tuple[int, int]] ] = None, + end_values: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, reflect_type: str = None, **pad_width_kwargs: Any, ) -> "DataArray": diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 98374fa2ba3..43f83c77bc8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -210,7 +210,7 @@ def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, in def merge_indexes( - indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]], + indexes: Mapping[Any, Union[Hashable, Sequence[Hashable]]], variables: Mapping[Any, Variable], coord_names: Set[Hashable], append: bool = False, @@ -510,7 +510,7 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults -class DataVariables(Mapping[Hashable, "DataArray"]): +class DataVariables(Mapping[Any, "DataArray"]): __slots__ = ("_dataset",) def __init__(self, dataset: "Dataset"): @@ -2110,7 +2110,7 @@ def chunk( chunks: Union[ int, str, - Mapping[Hashable, Union[None, int, str, Tuple[int, ...]]], + Mapping[Any, Union[None, int, str, Tuple[int, ...]]], ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, @@ -2485,7 +2485,7 @@ def sel( def head( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, int], int] = None, **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the first `n` values of each array @@ -2531,7 +2531,7 @@ def head( def tail( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, int], int] = None, **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with the last `n` values of each array @@ -2580,7 +2580,7 @@ def tail( def thin( self, - indexers: Union[Mapping[Hashable, int], int] = None, + indexers: Union[Mapping[Any, 
int], int] = None, **indexers_kwargs: Any, ) -> "Dataset": """Returns a new dataset with each array indexed along every `n`-th @@ -3559,7 +3559,7 @@ def swap_dims( def expand_dims( self, - dim: Union[None, Hashable, Sequence[Hashable], Mapping[Hashable, Any]] = None, + dim: Union[None, Hashable, Sequence[Hashable], Mapping[Any, Any]] = None, axis: Union[None, int, Sequence[int]] = None, **dim_kwargs: Any, ) -> "Dataset": @@ -3691,7 +3691,7 @@ def expand_dims( def set_index( self, - indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, + indexes: Mapping[Any, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> "Dataset": @@ -3789,7 +3789,7 @@ def reset_index( def reorder_levels( self, - dim_order: Mapping[Hashable, Sequence[int]] = None, + dim_order: Mapping[Any, Sequence[int]] = None, **dim_order_kwargs: Sequence[int], ) -> "Dataset": """Rearrange index levels using input order. @@ -3858,7 +3858,7 @@ def _stack_once(self, dims, new_dim): def stack( self, - dimensions: Mapping[Hashable, Sequence[Hashable]] = None, + dimensions: Mapping[Any, Sequence[Hashable]] = None, **dimensions_kwargs: Sequence[Hashable], ) -> "Dataset": """ @@ -6929,17 +6929,13 @@ def polyfit( def pad( self, - pad_width: Mapping[Hashable, Union[int, Tuple[int, int]]] = None, + pad_width: Mapping[Any, Union[int, Tuple[int, int]]] = None, mode: str = "constant", - stat_length: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, + stat_length: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, constant_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, - end_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] + int, Tuple[int, int], Mapping[Any, Tuple[int, int]] ] = None, + end_values: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, reflect_type: str = None, **pad_width_kwargs: Any, ) -> "Dataset": diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 429c37af588..95b6ccaad30 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -34,7 +34,7 @@ class Index: @classmethod def from_variables( - cls, variables: Mapping[Hashable, "Variable"] + cls, variables: Mapping[Any, "Variable"] ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover raise NotImplementedError() @@ -153,7 +153,7 @@ def __init__(self, array: Any, dim: Hashable): self.dim = dim @classmethod - def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + def from_variables(cls, variables: Mapping[Any, "Variable"]): from .variable import IndexVariable if len(variables) != 1: @@ -291,7 +291,7 @@ def _create_variables_from_multiindex(index, dim, level_meta=None): class PandasMultiIndex(PandasIndex): @classmethod - def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + def from_variables(cls, variables: Mapping[Any, "Variable"]): if any([var.ndim != 1 for var in variables.values()]): raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") @@ -499,7 +499,7 @@ def isel_variable_and_index( name: Hashable, variable: "Variable", index: Index, - indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]], + indexers: Mapping[Any, Union[int, slice, np.ndarray, "Variable"]], ) -> Tuple["Variable", Optional[Index]]: """Index a Variable and an Index together. 
diff --git a/xarray/core/merge.py b/xarray/core/merge.py index eaa5b62b2a9..a89e767826d 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import ( TYPE_CHECKING, AbstractSet, @@ -38,9 +40,9 @@ Tuple[DimsLike, ArrayLike, Mapping, Mapping], ] XarrayValue = Union[DataArray, Variable, VariableLike] - DatasetLike = Union[Dataset, Mapping[Hashable, XarrayValue]] + DatasetLike = Union[Dataset, Mapping[Any, XarrayValue]] CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame] - CoercibleMapping = Union[Dataset, Mapping[Hashable, CoercibleValue]] + CoercibleMapping = Union[Dataset, Mapping[Any, CoercibleValue]] PANDAS_TYPES = (pd.Series, pd.DataFrame, pdcompat.Panel) @@ -253,7 +255,7 @@ def merge_collected( def collect_variables_and_indexes( - list_of_mappings: "List[DatasetLike]", + list_of_mappings: List[DatasetLike], ) -> Dict[Hashable, List[MergeElement]]: """Collect variables and indexes from list of mappings of xarray objects. @@ -292,12 +294,14 @@ def append_all(variables, indexes): append_all(coords, indexes) variable = as_variable(variable, name=name) + if variable.dims == (name,): - variable = variable.to_index_variable() + idx_variable = variable.to_index_variable() index = variable._to_xindex() + append(name, idx_variable, index) else: index = None - append(name, variable, index) + append(name, variable, index) return grouped @@ -455,7 +459,7 @@ def merge_coords( compat: str = "minimal", join: str = "outer", priority_arg: Optional[int] = None, - indexes: Optional[Mapping[Hashable, Index]] = None, + indexes: Optional[Mapping[Any, Index]] = None, fill_value: object = dtypes.NA, ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge coordinate variables. @@ -578,7 +582,7 @@ def merge_core( combine_attrs: Optional[str] = "override", priority_arg: Optional[int] = None, explicit_coords: Optional[Sequence] = None, - indexes: Optional[Mapping[Hashable, Any]] = None, + indexes: Optional[Mapping[Any, Any]] = None, fill_value: object = dtypes.NA, ) -> _MergeResult: """Core logic for merging labeled objects. 
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 31718267ee0..2de555d422d 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,7 +1,7 @@ from __future__ import annotations from distutils.version import LooseVersion -from typing import Generic, Hashable, Mapping, Union +from typing import Any, Generic, Mapping, Union import numpy as np @@ -101,7 +101,7 @@ class RollingExp(Generic[T_Xarray]): def __init__( self, obj: T_Xarray, - windows: Mapping[Hashable, Union[int, float]], + windows: Mapping[Any, Union[int, float]], window_type: str = "span", ): self.obj: T_Xarray = obj diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 57b7035c940..77d973f613f 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -274,7 +274,7 @@ def is_duck_array(value: Any) -> bool: def either_dict_or_kwargs( - pos_kwargs: Optional[Mapping[Hashable, T]], + pos_kwargs: Optional[Mapping[Any, T]], kw_kwargs: Mapping[str, T], func_name: str, ) -> Mapping[Hashable, T]: @@ -817,7 +817,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( indexers: Mapping[Any, Any], - dims: Union[list, Mapping[Hashable, int]], + dims: Union[list, Mapping[Any, int]], missing_dims: str, ) -> Mapping[Hashable, Any]: """Depending on the setting of missing_dims, drop any dimensions from indexers that diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f96cbe63d07..191bb4059f5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -79,7 +79,7 @@ class MissingDimensionsError(ValueError): # TODO: move this to an xarray.exceptions module? -def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": +def as_variable(obj, name=None) -> Union[Variable, IndexVariable]: """Convert an object into a Variable. 
Parameters @@ -1251,7 +1251,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): def _pad_options_dim_to_index( self, - pad_option: Mapping[Hashable, Union[int, Tuple[int, int]]], + pad_option: Mapping[Any, Union[int, Tuple[int, int]]], fill_with_shape=False, ): if fill_with_shape: @@ -1263,17 +1263,13 @@ def _pad_options_dim_to_index( def pad( self, - pad_width: Mapping[Hashable, Union[int, Tuple[int, int]]] = None, + pad_width: Mapping[Any, Union[int, Tuple[int, int]]] = None, mode: str = "constant", - stat_length: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, + stat_length: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, constant_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] - ] = None, - end_values: Union[ - int, Tuple[int, int], Mapping[Hashable, Tuple[int, int]] + int, Tuple[int, int], Mapping[Any, Tuple[int, int]] ] = None, + end_values: Union[int, Tuple[int, int], Mapping[Any, Tuple[int, int]]] = None, reflect_type: str = None, **pad_width_kwargs: Any, ): diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 62da3bab2cd..135aa058439 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -69,7 +69,7 @@ def setup(self): "is_leap_year", ], ) - def test_field_access(self, field): + def test_field_access(self, field) -> None: if LooseVersion(pd.__version__) >= "1.1.0" and field in ["week", "weekofyear"]: data = self.times.isocalendar()["week"] @@ -96,7 +96,7 @@ def test_field_access(self, field): ("weekday", "day"), ], ) - def test_isocalendar(self, field, pandas_field): + def test_isocalendar(self, field, pandas_field) -> None: if LooseVersion(pd.__version__) < "1.1.0": with pytest.raises( @@ -114,12 +114,12 @@ def test_isocalendar(self, field, pandas_field): actual = self.data.time.dt.isocalendar()[field] assert_equal(expected, actual) - def test_strftime(self): + def test_strftime(self) -> None: assert ( "2000-01-01 01:00:00" == self.data.time.dt.strftime("%Y-%m-%d %H:%M:%S")[1] ) - def test_not_datetime_type(self): + def test_not_datetime_type(self) -> None: nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) @@ -156,7 +156,7 @@ def test_not_datetime_type(self): "is_leap_year", ], ) - def test_dask_field_access(self, field): + def test_dask_field_access(self, field) -> None: import dask.array as da expected = getattr(self.times_data.dt, field) @@ -182,7 +182,7 @@ def test_dask_field_access(self, field): "weekday", ], ) - def test_isocalendar_dask(self, field): + def test_isocalendar_dask(self, field) -> None: import dask.array as da if LooseVersion(pd.__version__) < "1.1.0": @@ -216,7 +216,7 @@ def test_isocalendar_dask(self, field): ("strftime", "%Y-%m-%d %H:%M:%S"), ], ) - def test_dask_accessor_method(self, method, parameters): + def test_dask_accessor_method(self, method, parameters) -> None: import dask.array as da expected = getattr(self.times_data.dt, method)(parameters) @@ -232,31 +232,32 @@ def test_dask_accessor_method(self, method, parameters): assert_chunks_equal(actual, dask_times_2d) assert_equal(actual.compute(), expected.compute()) - def test_seasons(self): + def test_seasons(self) -> None: dates = pd.date_range(start="2000/01/01", freq="M", periods=12) dates = xr.DataArray(dates) - seasons = [ - "DJF", - "DJF", - "MAM", - "MAM", - "MAM", - "JJA", - "JJA", - "JJA", - "SON", - "SON", - "SON", - "DJF", - ] - 
seasons = xr.DataArray(seasons) + seasons = xr.DataArray( + [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + ] + ) assert_array_equal(seasons.values, dates.dt.season.values) @pytest.mark.parametrize( "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] ) - def test_accessor_method(self, method, parameters): + def test_accessor_method(self, method, parameters) -> None: dates = pd.date_range("2014-01-01", "2014-05-01", freq="H") xdates = xr.DataArray(dates, dims=["time"]) expected = getattr(dates, method)(parameters) @@ -288,7 +289,7 @@ def setup(self): name="data", ) - def test_not_datetime_type(self): + def test_not_datetime_type(self) -> None: nontime_data = self.data.copy() int_data = np.arange(len(self.data.time)).astype("int8") nontime_data = nontime_data.assign_coords(time=int_data) @@ -298,7 +299,7 @@ def test_not_datetime_type(self): @pytest.mark.parametrize( "field", ["days", "seconds", "microseconds", "nanoseconds"] ) - def test_field_access(self, field): + def test_field_access(self, field) -> None: expected = xr.DataArray( getattr(self.times, field), name=field, coords=[self.times], dims=["time"] ) @@ -308,7 +309,7 @@ def test_field_access(self, field): @pytest.mark.parametrize( "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] ) - def test_accessor_methods(self, method, parameters): + def test_accessor_methods(self, method, parameters) -> None: dates = pd.timedelta_range(start="1 day", end="30 days", freq="6H") xdates = xr.DataArray(dates, dims=["time"]) expected = getattr(dates, method)(parameters) @@ -319,7 +320,7 @@ def test_accessor_methods(self, method, parameters): @pytest.mark.parametrize( "field", ["days", "seconds", "microseconds", "nanoseconds"] ) - def test_dask_field_access(self, field): + def test_dask_field_access(self, field) -> None: import dask.array as da expected = getattr(self.times_data.dt, field) @@ -340,7 +341,7 @@ def test_dask_field_access(self, field): @pytest.mark.parametrize( "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")] ) - def test_dask_accessor_method(self, method, parameters): + def test_dask_accessor_method(self, method, parameters) -> None: import dask.array as da expected = getattr(self.times_data.dt, method)(parameters) @@ -410,7 +411,7 @@ def times_3d(times): @pytest.mark.parametrize( "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] ) -def test_field_access(data, field): +def test_field_access(data, field) -> None: if field == "dayofyear" or field == "dayofweek": pytest.importorskip("cftime", minversion="1.0.2.1") result = getattr(data.time.dt, field) @@ -425,7 +426,7 @@ def test_field_access(data, field): @requires_cftime -def test_isocalendar_cftime(data): +def test_isocalendar_cftime(data) -> None: with pytest.raises( AttributeError, match=r"'CFTimeIndex' object has no attribute 'isocalendar'" @@ -434,7 +435,7 @@ def test_isocalendar_cftime(data): @requires_cftime -def test_date_cftime(data): +def test_date_cftime(data) -> None: with pytest.raises( AttributeError, @@ -445,7 +446,7 @@ def test_date_cftime(data): @requires_cftime @pytest.mark.filterwarnings("ignore::RuntimeWarning") -def test_cftime_strftime_access(data): +def test_cftime_strftime_access(data) -> None: """compare cftime formatting against datetime formatting""" date_format = "%Y%m%d%H" result = data.time.dt.strftime(date_format) @@ -464,7 +465,7 @@ def test_cftime_strftime_access(data): @pytest.mark.parametrize( "field", ["year", 
"month", "day", "hour", "dayofyear", "dayofweek"] ) -def test_dask_field_access_1d(data, field): +def test_dask_field_access_1d(data, field) -> None: import dask.array as da if field == "dayofyear" or field == "dayofweek": @@ -486,7 +487,7 @@ def test_dask_field_access_1d(data, field): @pytest.mark.parametrize( "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"] ) -def test_dask_field_access(times_3d, data, field): +def test_dask_field_access(times_3d, data, field) -> None: import dask.array as da if field == "dayofyear" or field == "dayofweek": @@ -514,24 +515,26 @@ def cftime_date_type(calendar): @requires_cftime -def test_seasons(cftime_date_type): - dates = np.array([cftime_date_type(2000, month, 15) for month in range(1, 13)]) - dates = xr.DataArray(dates) - seasons = [ - "DJF", - "DJF", - "MAM", - "MAM", - "MAM", - "JJA", - "JJA", - "JJA", - "SON", - "SON", - "SON", - "DJF", - ] - seasons = xr.DataArray(seasons) +def test_seasons(cftime_date_type) -> None: + dates = xr.DataArray( + np.array([cftime_date_type(2000, month, 15) for month in range(1, 13)]) + ) + seasons = xr.DataArray( + [ + "DJF", + "DJF", + "MAM", + "MAM", + "MAM", + "JJA", + "JJA", + "JJA", + "SON", + "SON", + "SON", + "DJF", + ] + ) assert_array_equal(seasons.values, dates.dt.season.values) @@ -549,7 +552,9 @@ def cftime_rounding_dataarray(cftime_date_type): @requires_cftime @requires_dask @pytest.mark.parametrize("use_dask", [False, True]) -def test_cftime_floor_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): +def test_cftime_floor_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: import dask.array as da freq = "D" @@ -580,7 +585,9 @@ def test_cftime_floor_accessor(cftime_rounding_dataarray, cftime_date_type, use_ @requires_cftime @requires_dask @pytest.mark.parametrize("use_dask", [False, True]) -def test_cftime_ceil_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): +def test_cftime_ceil_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: import dask.array as da freq = "D" @@ -611,7 +618,9 @@ def test_cftime_ceil_accessor(cftime_rounding_dataarray, cftime_date_type, use_d @requires_cftime @requires_dask @pytest.mark.parametrize("use_dask", [False, True]) -def test_cftime_round_accessor(cftime_rounding_dataarray, cftime_date_type, use_dask): +def test_cftime_round_accessor( + cftime_rounding_dataarray, cftime_date_type, use_dask +) -> None: import dask.array as da freq = "D" diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index 519ca762c41..e3c45d732e4 100644 --- a/xarray/tests/test_accessor_str.py +++ b/xarray/tests/test_accessor_str.py @@ -37,6 +37,8 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# type: ignore[assignment] + import re import numpy as np @@ -53,7 +55,7 @@ def dtype(request): @requires_dask -def test_dask(): +def test_dask() -> None: import dask.array as da arr = da.from_array(["a", "b", "c"], chunks=-1) @@ -65,7 +67,7 @@ def test_dask(): assert_equal(result, expected) -def test_count(dtype): +def test_count(dtype) -> None: values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) pat_str = dtype(r"f[o]+") pat_re = re.compile(pat_str) @@ -81,7 +83,7 @@ def test_count(dtype): assert_equal(result_re, expected) -def test_count_broadcast(dtype): +def test_count_broadcast(dtype) -> None: values = xr.DataArray(["foo", "foofoo", "foooofooofommmfoo"]).astype(dtype) pat_str = np.array([r"f[o]+", r"o", r"m"]).astype(dtype) pat_re = np.array([re.compile(x) for x in pat_str]) @@ -97,7 +99,7 @@ def test_count_broadcast(dtype): assert_equal(result_re, expected) -def test_contains(dtype): +def test_contains(dtype) -> None: values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"]).astype(dtype) # case insensitive using regex @@ -141,7 +143,7 @@ def test_contains(dtype): values.str.contains(pat_re, regex=False) -def test_contains_broadcast(dtype): +def test_contains_broadcast(dtype) -> None: values = xr.DataArray(["Foo", "xYz", "fOOomMm__fOo", "MMM_"], dims="X").astype( dtype ) @@ -208,7 +210,7 @@ def test_contains_broadcast(dtype): assert_equal(result, expected) -def test_starts_ends_with(dtype): +def test_starts_ends_with(dtype) -> None: values = xr.DataArray(["om", "foo_nom", "nom", "bar_foo", "foo"]).astype(dtype) result = values.str.startswith("foo") @@ -222,7 +224,7 @@ def test_starts_ends_with(dtype): assert_equal(result, expected) -def test_starts_ends_with_broadcast(dtype): +def test_starts_ends_with_broadcast(dtype) -> None: values = xr.DataArray( ["om", "foo_nom", "nom", "bar_foo", "foo_bar"], dims="X" ).astype(dtype) @@ -245,7 +247,7 @@ def test_starts_ends_with_broadcast(dtype): assert_equal(result, expected) -def test_case_bytes(): +def test_case_bytes() -> None: value = xr.DataArray(["SOme wOrd"]).astype(np.bytes_) exp_capitalized = xr.DataArray(["Some word"]).astype(np.bytes_) @@ -273,7 +275,7 @@ def test_case_bytes(): assert_equal(res_uppered, exp_uppered) -def test_case_str(): +def test_case_str() -> None: # This string includes some unicode characters # that are common case management corner cases value = xr.DataArray(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.unicode_) @@ -331,7 +333,7 @@ def test_case_str(): assert_equal(res_norm_nfkd, exp_norm_nfkd) -def test_replace(dtype): +def test_replace(dtype) -> None: values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) result = values.str.replace("BAD[_]*", "") expected = xr.DataArray(["foobar"], dims=["x"]).astype(dtype) @@ -385,7 +387,7 @@ def test_replace(dtype): assert_equal(result, expected) -def test_replace_callable(): +def test_replace_callable() -> None: values = xr.DataArray(["fooBAD__barBAD"]) # test with callable @@ -421,7 +423,7 @@ def test_replace_callable(): assert_equal(result, exp) -def test_replace_unicode(): +def test_replace_unicode() -> None: # flags + unicode values = xr.DataArray([b"abcd,\xc3\xa0".decode("utf-8")]) expected = xr.DataArray([b"abcd, \xc3\xa0".decode("utf-8")]) @@ -445,7 +447,7 @@ def test_replace_unicode(): assert_equal(result, expected) -def test_replace_compiled_regex(dtype): +def test_replace_compiled_regex(dtype) -> None: values = xr.DataArray(["fooBAD__barBAD"], dims=["x"]).astype(dtype) # test with compiled regex @@ -507,7 +509,7 @@ def 
test_replace_compiled_regex(dtype): assert_equal(result, expected) -def test_replace_literal(dtype): +def test_replace_literal(dtype) -> None: # GH16808 literal replace (regex=False vs regex=True) values = xr.DataArray(["f.o", "foo"], dims=["X"]).astype(dtype) expected = xr.DataArray(["bao", "bao"], dims=["X"]).astype(dtype) @@ -550,7 +552,7 @@ def test_replace_literal(dtype): values.str.replace(compiled_pat, "", regex=False) -def test_extract_extractall_findall_empty_raises(dtype): +def test_extract_extractall_findall_empty_raises(dtype) -> None: pat_str = dtype(r".*") pat_re = re.compile(pat_str) @@ -575,7 +577,7 @@ def test_extract_extractall_findall_empty_raises(dtype): value.str.findall(pat=pat_re) -def test_extract_multi_None_raises(dtype): +def test_extract_multi_None_raises(dtype) -> None: pat_str = r"(\w+)_(\d+)" pat_re = re.compile(pat_str) @@ -594,7 +596,7 @@ def test_extract_multi_None_raises(dtype): value.str.extract(pat=pat_re, dim=None) -def test_extract_extractall_findall_case_re_raises(dtype): +def test_extract_extractall_findall_case_re_raises(dtype) -> None: pat_str = r".*" pat_re = re.compile(pat_str) @@ -631,7 +633,7 @@ def test_extract_extractall_findall_case_re_raises(dtype): value.str.findall(pat=pat_re, case=False) -def test_extract_extractall_name_collision_raises(dtype): +def test_extract_extractall_name_collision_raises(dtype) -> None: pat_str = r"(\w+)" pat_re = re.compile(pat_str) @@ -674,7 +676,7 @@ def test_extract_extractall_name_collision_raises(dtype): value.str.extractall(pat=pat_re, group_dim="ZZ", match_dim="ZZ") -def test_extract_single_case(dtype): +def test_extract_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -720,7 +722,7 @@ def test_extract_single_case(dtype): assert_equal(res_re_dim, targ_dim) -def test_extract_single_nocase(dtype): +def test_extract_single_nocase(dtype) -> None: pat_str = r"(\w+)?_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.IGNORECASE) @@ -760,7 +762,7 @@ def test_extract_single_nocase(dtype): assert_equal(res_re_dim, targ_dim) -def test_extract_multi_case(dtype): +def test_extract_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -798,7 +800,7 @@ def test_extract_multi_case(dtype): assert_equal(res_str_case, expected) -def test_extract_multi_nocase(dtype): +def test_extract_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.IGNORECASE) @@ -833,7 +835,7 @@ def test_extract_multi_nocase(dtype): assert_equal(res_re, expected) -def test_extract_broadcast(dtype): +def test_extract_broadcast(dtype) -> None: value = xr.DataArray( ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], dims=["X"], @@ -862,7 +864,7 @@ def test_extract_broadcast(dtype): assert_equal(res_re, expected) -def test_extractall_single_single_case(dtype): +def test_extractall_single_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -892,7 +894,7 @@ def test_extractall_single_single_case(dtype): assert_equal(res_str_case, expected) -def test_extractall_single_single_nocase(dtype): +def test_extractall_single_single_nocase(dtype) -> None: 
pat_str = r"(\w+)_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.I) @@ -919,7 +921,7 @@ def test_extractall_single_single_nocase(dtype): assert_equal(res_re, expected) -def test_extractall_single_multi_case(dtype): +def test_extractall_single_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -963,7 +965,7 @@ def test_extractall_single_multi_case(dtype): assert_equal(res_str_case, expected) -def test_extractall_single_multi_nocase(dtype): +def test_extractall_single_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.I) @@ -1008,7 +1010,7 @@ def test_extractall_single_multi_nocase(dtype): assert_equal(res_re, expected) -def test_extractall_multi_single_case(dtype): +def test_extractall_multi_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -1041,7 +1043,7 @@ def test_extractall_multi_single_case(dtype): assert_equal(res_str_case, expected) -def test_extractall_multi_single_nocase(dtype): +def test_extractall_multi_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.I) @@ -1071,7 +1073,7 @@ def test_extractall_multi_single_nocase(dtype): assert_equal(res_re, expected) -def test_extractall_multi_multi_case(dtype): +def test_extractall_multi_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re) @@ -1119,7 +1121,7 @@ def test_extractall_multi_multi_case(dtype): assert_equal(res_str_case, expected) -def test_extractall_multi_multi_nocase(dtype): +def test_extractall_multi_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = pat_str if dtype == np.unicode_ else bytes(pat_str, encoding="UTF-8") pat_re = re.compile(pat_re, flags=re.I) @@ -1164,7 +1166,7 @@ def test_extractall_multi_multi_nocase(dtype): assert_equal(res_re, expected) -def test_extractall_broadcast(dtype): +def test_extractall_broadcast(dtype) -> None: value = xr.DataArray( ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], dims=["X"], @@ -1193,7 +1195,7 @@ def test_extractall_broadcast(dtype): assert_equal(res_re, expected) -def test_findall_single_single_case(dtype): +def test_findall_single_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = re.compile(dtype(pat_str)) @@ -1220,7 +1222,7 @@ def test_findall_single_single_case(dtype): assert_equal(res_str_case, expected) -def test_findall_single_single_nocase(dtype): +def test_findall_single_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = re.compile(dtype(pat_str), flags=re.I) @@ -1244,7 +1246,7 @@ def test_findall_single_single_nocase(dtype): assert_equal(res_re, expected) -def test_findall_single_multi_case(dtype): +def test_findall_single_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = re.compile(dtype(pat_str)) @@ -1285,7 +1287,7 @@ def test_findall_single_multi_case(dtype): assert_equal(res_str_case, expected) -def test_findall_single_multi_nocase(dtype): +def test_findall_single_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_\d*" pat_re = re.compile(dtype(pat_str), flags=re.I) @@ -1327,7 
+1329,7 @@ def test_findall_single_multi_nocase(dtype): assert_equal(res_re, expected) -def test_findall_multi_single_case(dtype): +def test_findall_multi_single_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = re.compile(dtype(pat_str)) @@ -1357,7 +1359,7 @@ def test_findall_multi_single_case(dtype): assert_equal(res_str_case, expected) -def test_findall_multi_single_nocase(dtype): +def test_findall_multi_single_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = re.compile(dtype(pat_str), flags=re.I) @@ -1384,7 +1386,7 @@ def test_findall_multi_single_nocase(dtype): assert_equal(res_re, expected) -def test_findall_multi_multi_case(dtype): +def test_findall_multi_multi_case(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = re.compile(dtype(pat_str)) @@ -1429,7 +1431,7 @@ def test_findall_multi_multi_case(dtype): assert_equal(res_str_case, expected) -def test_findall_multi_multi_nocase(dtype): +def test_findall_multi_multi_nocase(dtype) -> None: pat_str = r"(\w+)_Xy_(\d*)" pat_re = re.compile(dtype(pat_str), flags=re.I) @@ -1471,7 +1473,7 @@ def test_findall_multi_multi_nocase(dtype): assert_equal(res_re, expected) -def test_findall_broadcast(dtype): +def test_findall_broadcast(dtype) -> None: value = xr.DataArray( ["a_Xy_0", "ab_xY_10", "abc_Xy_01"], dims=["X"], @@ -1498,7 +1500,7 @@ def test_findall_broadcast(dtype): assert_equal(res_re, expected) -def test_repeat(dtype): +def test_repeat(dtype) -> None: values = xr.DataArray(["a", "b", "c", "d"]).astype(dtype) result = values.str.repeat(3) @@ -1513,7 +1515,7 @@ def test_repeat(dtype): assert_equal(result, expected) -def test_repeat_broadcast(dtype): +def test_repeat_broadcast(dtype) -> None: values = xr.DataArray(["a", "b", "c", "d"], dims=["X"]).astype(dtype) reps = xr.DataArray([3, 4], dims=["Y"]) @@ -1532,7 +1534,7 @@ def test_repeat_broadcast(dtype): assert_equal(result, expected) -def test_match(dtype): +def test_match(dtype) -> None: values = xr.DataArray(["fooBAD__barBAD", "foo"]).astype(dtype) # New match behavior introduced in 0.13 @@ -1566,7 +1568,7 @@ def test_match(dtype): assert_equal(result, expected) -def test_empty_str_methods(): +def test_empty_str_methods() -> None: empty = xr.DataArray(np.empty(shape=(0,), dtype="U")) empty_str = empty empty_int = xr.DataArray(np.empty(shape=(0,), dtype=int)) @@ -1652,7 +1654,7 @@ def test_empty_str_methods(): assert_equal(empty_str, empty.str.translate(table)) -def test_ismethods(dtype): +def test_ismethods(dtype) -> None: values = ["A", "b", "Xy", "4", "3A", "", "TT", "55", "-", " "] exp_alnum = [True, True, True, True, True, False, True, True, False, False] @@ -1698,7 +1700,7 @@ def test_ismethods(dtype): assert_equal(res_upper, exp_upper) -def test_isnumeric(): +def test_isnumeric() -> None: # 0x00bc: ¼ VULGAR FRACTION ONE QUARTER # 0x2605: ★ not number # 0x1378: ፸ ETHIOPIC NUMBER SEVENTY @@ -1721,7 +1723,7 @@ def test_isnumeric(): assert_equal(res_decimal, exp_decimal) -def test_len(dtype): +def test_len(dtype) -> None: values = ["foo", "fooo", "fooooo", "fooooooo"] result = xr.DataArray(values).astype(dtype).str.len() expected = xr.DataArray([len(x) for x in values]) @@ -1729,7 +1731,7 @@ def test_len(dtype): assert_equal(result, expected) -def test_find(dtype): +def test_find(dtype) -> None: values = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"]) values = values.astype(dtype) @@ -1812,7 +1814,7 @@ def test_find(dtype): assert_equal(result_1, expected_1) -def test_find_broadcast(dtype): +def test_find_broadcast(dtype) -> None: values = 
xr.DataArray( ["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF", "XXX"], dims=["X"] ) @@ -1858,7 +1860,7 @@ def test_find_broadcast(dtype): assert_equal(result_1, expected) -def test_index(dtype): +def test_index(dtype) -> None: s = xr.DataArray(["ABCDEFG", "BCDEFEF", "DEFGHIJEF", "EFGHEF"]).astype(dtype) result_0 = s.str.index("EF") @@ -1914,7 +1916,7 @@ def test_index(dtype): s.str.index("DE") -def test_index_broadcast(dtype): +def test_index_broadcast(dtype) -> None: values = xr.DataArray( ["ABCDEFGEFDBCA", "BCDEFEFEFDBC", "DEFBCGHIEFBC", "EFGHBCEFBCBCBCEF"], dims=["X"], @@ -1949,7 +1951,7 @@ def test_index_broadcast(dtype): assert_equal(result_1, expected) -def test_translate(): +def test_translate() -> None: values = xr.DataArray(["abcdefg", "abcc", "cdddfg", "cdefggg"]) table = str.maketrans("abc", "cde") result = values.str.translate(table) @@ -1958,7 +1960,7 @@ def test_translate(): assert_equal(result, expected) -def test_pad_center_ljust_rjust(dtype): +def test_pad_center_ljust_rjust(dtype) -> None: values = xr.DataArray(["a", "b", "c", "eeeee"]).astype(dtype) result = values.str.center(5) @@ -1986,7 +1988,7 @@ def test_pad_center_ljust_rjust(dtype): assert_equal(result, expected) -def test_pad_center_ljust_rjust_fillchar(dtype): +def test_pad_center_ljust_rjust_fillchar(dtype) -> None: values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"]).astype(dtype) result = values.str.center(5, fillchar="X") @@ -2037,7 +2039,7 @@ def test_pad_center_ljust_rjust_fillchar(dtype): values.str.pad(5, fillchar="XY") -def test_pad_center_ljust_rjust_broadcast(dtype): +def test_pad_center_ljust_rjust_broadcast(dtype) -> None: values = xr.DataArray(["a", "bb", "cccc", "ddddd", "eeeeee"], dims="X").astype( dtype ) @@ -2096,7 +2098,7 @@ def test_pad_center_ljust_rjust_broadcast(dtype): assert_equal(result, expected) -def test_zfill(dtype): +def test_zfill(dtype) -> None: values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) result = values.str.zfill(5) @@ -2110,7 +2112,7 @@ def test_zfill(dtype): assert_equal(result, expected) -def test_zfill_broadcast(dtype): +def test_zfill_broadcast(dtype) -> None: values = xr.DataArray(["1", "22", "aaa", "333", "45678"]).astype(dtype) width = np.array([4, 5, 0, 3, 8]) @@ -2120,7 +2122,7 @@ def test_zfill_broadcast(dtype): assert_equal(result, expected) -def test_slice(dtype): +def test_slice(dtype) -> None: arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) result = arr.str.slice(2, 5) @@ -2138,7 +2140,7 @@ def test_slice(dtype): raise -def test_slice_broadcast(dtype): +def test_slice_broadcast(dtype) -> None: arr = xr.DataArray(["aafootwo", "aabartwo", "aabazqux"]).astype(dtype) start = xr.DataArray([1, 2, 3]) stop = 5 @@ -2149,7 +2151,7 @@ def test_slice_broadcast(dtype): assert_equal(result, exp) -def test_slice_replace(dtype): +def test_slice_replace(dtype) -> None: da = lambda x: xr.DataArray(x).astype(dtype) values = da(["short", "a bit longer", "evenlongerthanthat", ""]) @@ -2194,7 +2196,7 @@ def test_slice_replace(dtype): assert_equal(result, expected) -def test_slice_replace_broadcast(dtype): +def test_slice_replace_broadcast(dtype) -> None: values = xr.DataArray(["short", "a bit longer", "evenlongerthanthat", ""]).astype( dtype ) @@ -2210,7 +2212,7 @@ def test_slice_replace_broadcast(dtype): assert_equal(result, expected) -def test_strip_lstrip_rstrip(dtype): +def test_strip_lstrip_rstrip(dtype) -> None: values = xr.DataArray([" aa ", " bb \n", "cc "]).astype(dtype) result = values.str.strip() @@ -2229,7 +2231,7 @@ 
def test_strip_lstrip_rstrip(dtype): assert_equal(result, expected) -def test_strip_lstrip_rstrip_args(dtype): +def test_strip_lstrip_rstrip_args(dtype) -> None: values = xr.DataArray(["xxABCxx", "xx BNSD", "LDFJH xx"]).astype(dtype) result = values.str.strip("x") @@ -2248,7 +2250,7 @@ def test_strip_lstrip_rstrip_args(dtype): assert_equal(result, expected) -def test_strip_lstrip_rstrip_broadcast(dtype): +def test_strip_lstrip_rstrip_broadcast(dtype) -> None: values = xr.DataArray(["xxABCxx", "yy BNSD", "LDFJH zz"]).astype(dtype) to_strip = xr.DataArray(["x", "y", "z"]).astype(dtype) @@ -2268,7 +2270,7 @@ def test_strip_lstrip_rstrip_broadcast(dtype): assert_equal(result, expected) -def test_wrap(): +def test_wrap() -> None: # test values are: two words less than width, two words equal to width, # two words greater than width, one word less than width, one word # equal to width, one word greater than width, multiple tokens with @@ -2315,7 +2317,7 @@ def test_wrap(): assert_equal(result, expected) -def test_wrap_kwargs_passed(): +def test_wrap_kwargs_passed() -> None: # GH4334 values = xr.DataArray(" hello world ") @@ -2331,7 +2333,7 @@ def test_wrap_kwargs_passed(): assert_equal(result, expected) -def test_get(dtype): +def test_get(dtype) -> None: values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"]).astype(dtype) result = values.str[2] @@ -2355,7 +2357,7 @@ def test_get(dtype): assert_equal(result, expected) -def test_get_default(dtype): +def test_get_default(dtype) -> None: # GH4334 values = xr.DataArray(["a_b", "c", ""]).astype(dtype) @@ -2365,7 +2367,7 @@ def test_get_default(dtype): assert_equal(result, expected) -def test_get_broadcast(dtype): +def test_get_broadcast(dtype) -> None: values = xr.DataArray(["a_b_c", "c_d_e", "f_g_h"], dims=["X"]).astype(dtype) inds = xr.DataArray([0, 2], dims=["Y"]) @@ -2377,7 +2379,7 @@ def test_get_broadcast(dtype): assert_equal(result, expected) -def test_encode_decode(): +def test_encode_decode() -> None: data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8") decoded = encoded.str.decode("utf-8") @@ -2385,7 +2387,7 @@ def test_encode_decode(): assert_equal(data, decoded) -def test_encode_decode_errors(): +def test_encode_decode_errors() -> None: encodeBase = xr.DataArray(["a", "b", "a\x9d"]) msg = ( @@ -2419,7 +2421,7 @@ def test_encode_decode_errors(): assert_equal(result, expected) -def test_partition_whitespace(dtype): +def test_partition_whitespace(dtype) -> None: values = xr.DataArray( [ ["abc def", "spam eggs swallow", "red_blue"], @@ -2467,7 +2469,7 @@ def test_partition_whitespace(dtype): assert_equal(res_rpart_dim, exp_rpart_dim) -def test_partition_comma(dtype): +def test_partition_comma(dtype) -> None: values = xr.DataArray( [ ["abc, def", "spam, eggs, swallow", "red_blue"], @@ -2515,7 +2517,7 @@ def test_partition_comma(dtype): assert_equal(res_rpart_dim, exp_rpart_dim) -def test_partition_empty(dtype): +def test_partition_empty(dtype) -> None: values = xr.DataArray([], dims=["X"]).astype(dtype) expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) @@ -2525,7 +2527,7 @@ def test_partition_empty(dtype): assert_equal(res, expected) -def test_split_whitespace(dtype): +def test_split_whitespace(dtype) -> None: values = xr.DataArray( [ ["abc def", "spam\t\teggs\tswallow", "red_blue"], @@ -2663,7 +2665,7 @@ def test_split_whitespace(dtype): assert_equal(res_rsplit_none_10, exp_rsplit_none_full) -def test_split_comma(dtype): +def test_split_comma(dtype) -> None: values = xr.DataArray( [ ["abc,def", 
"spam,,eggs,swallow", "red_blue"], @@ -2801,7 +2803,7 @@ def test_split_comma(dtype): assert_equal(res_rsplit_none_10, exp_rsplit_none_full) -def test_splitters_broadcast(dtype): +def test_splitters_broadcast(dtype) -> None: values = xr.DataArray( ["ab cd,de fg", "spam, ,eggs swallow", "red_blue"], dims=["X"], @@ -2865,7 +2867,7 @@ def test_splitters_broadcast(dtype): assert_equal(res_right, expected_right) -def test_split_empty(dtype): +def test_split_empty(dtype) -> None: values = xr.DataArray([], dims=["X"]).astype(dtype) expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) @@ -2875,7 +2877,7 @@ def test_split_empty(dtype): assert_equal(res, expected) -def test_get_dummies(dtype): +def test_get_dummies(dtype) -> None: values_line = xr.DataArray( [["a|ab~abc|abc", "ab", "a||abc|abcd"], ["abcd|ab|a", "abc|ab~abc", "|a"]], dims=["X", "Y"], @@ -2919,7 +2921,7 @@ def test_get_dummies(dtype): assert_equal(res_comma, targ_comma) -def test_get_dummies_broadcast(dtype): +def test_get_dummies_broadcast(dtype) -> None: values = xr.DataArray( ["x~x|x~x", "x", "x|x~x", "x~x"], dims=["X"], @@ -2947,7 +2949,7 @@ def test_get_dummies_broadcast(dtype): assert_equal(res, expected) -def test_get_dummies_empty(dtype): +def test_get_dummies_empty(dtype) -> None: values = xr.DataArray([], dims=["X"]).astype(dtype) expected = xr.DataArray(np.zeros((0, 0)), dims=["X", "ZZ"]).astype(dtype) @@ -2957,7 +2959,7 @@ def test_get_dummies_empty(dtype): assert_equal(res, expected) -def test_splitters_empty_str(dtype): +def test_splitters_empty_str(dtype) -> None: values = xr.DataArray( [["", "", ""], ["", "", ""]], dims=["X", "Y"], @@ -3032,7 +3034,7 @@ def test_splitters_empty_str(dtype): assert_equal(res_dummies, targ_split_dim) -def test_cat_str(dtype): +def test_cat_str(dtype) -> None: values_1 = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], dims=["X", "Y"], @@ -3078,7 +3080,7 @@ def test_cat_str(dtype): assert_equal(res_comma, targ_comma) -def test_cat_uniform(dtype): +def test_cat_uniform(dtype) -> None: values_1 = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], dims=["X", "Y"], @@ -3127,7 +3129,7 @@ def test_cat_uniform(dtype): assert_equal(res_comma, targ_comma) -def test_cat_broadcast_right(dtype): +def test_cat_broadcast_right(dtype) -> None: values_1 = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], dims=["X", "Y"], @@ -3176,7 +3178,7 @@ def test_cat_broadcast_right(dtype): assert_equal(res_comma, targ_comma) -def test_cat_broadcast_left(dtype): +def test_cat_broadcast_left(dtype) -> None: values_1 = xr.DataArray( ["a", "bb", "cccc"], dims=["Y"], @@ -3241,7 +3243,7 @@ def test_cat_broadcast_left(dtype): assert_equal(res_comma, targ_comma) -def test_cat_broadcast_both(dtype): +def test_cat_broadcast_both(dtype) -> None: values_1 = xr.DataArray( ["a", "bb", "cccc"], dims=["Y"], @@ -3306,7 +3308,7 @@ def test_cat_broadcast_both(dtype): assert_equal(res_comma, targ_comma) -def test_cat_multi(): +def test_cat_multi() -> None: values_1 = xr.DataArray( ["11111", "4"], dims=["X"], @@ -3350,7 +3352,7 @@ def test_cat_multi(): assert_equal(res, expected) -def test_join_scalar(dtype): +def test_join_scalar(dtype) -> None: values = xr.DataArray("aaa").astype(dtype) targ = xr.DataArray("aaa").astype(dtype) @@ -3365,7 +3367,7 @@ def test_join_scalar(dtype): assert_identical(res_space, targ) -def test_join_vector(dtype): +def test_join_vector(dtype) -> None: values = xr.DataArray( ["a", "bb", "cccc"], dims=["Y"], @@ -3391,7 +3393,7 @@ def 
test_join_vector(dtype): assert_identical(res_space_y, targ_space) -def test_join_2d(dtype): +def test_join_2d(dtype) -> None: values = xr.DataArray( [["a", "bb", "cccc"], ["ddddd", "eeee", "fff"]], dims=["X", "Y"], @@ -3437,7 +3439,7 @@ def test_join_2d(dtype): values.str.join() -def test_join_broadcast(dtype): +def test_join_broadcast(dtype) -> None: values = xr.DataArray( ["a", "bb", "cccc"], dims=["X"], @@ -3459,7 +3461,7 @@ def test_join_broadcast(dtype): assert_identical(res, expected) -def test_format_scalar(): +def test_format_scalar() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], @@ -3484,7 +3486,7 @@ def test_format_scalar(): assert_equal(res, expected) -def test_format_broadcast(): +def test_format_broadcast() -> None: values = xr.DataArray( ["{}.{Y}.{ZZ}", "{},{},{X},{X}", "{X}-{Y}-{ZZ}"], dims=["X"], @@ -3518,7 +3520,7 @@ def test_format_broadcast(): assert_equal(res, expected) -def test_mod_scalar(): +def test_mod_scalar() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], @@ -3539,7 +3541,7 @@ def test_mod_scalar(): assert_equal(res, expected) -def test_mod_dict(): +def test_mod_dict() -> None: values = xr.DataArray( ["%(a)s.%(a)s.%(b)s", "%(b)s,%(c)s,%(b)s", "%(c)s-%(b)s-%(a)s"], dims=["X"], @@ -3560,7 +3562,7 @@ def test_mod_dict(): assert_equal(res, expected) -def test_mod_broadcast_single(): +def test_mod_broadcast_single() -> None: values = xr.DataArray( ["%s_1", "%s_2", "%s_3"], dims=["X"], @@ -3582,7 +3584,7 @@ def test_mod_broadcast_single(): assert_equal(res, expected) -def test_mod_broadcast_multi(): +def test_mod_broadcast_multi() -> None: values = xr.DataArray( ["%s.%s.%s", "%s,%s,%s", "%s-%s-%s"], dims=["X"], diff --git a/xarray/tests/test_backends_api.py b/xarray/tests/test_backends_api.py index 4124d0d0b81..cd62ebd4239 100644 --- a/xarray/tests/test_backends_api.py +++ b/xarray/tests/test_backends_api.py @@ -8,7 +8,7 @@ @requires_netCDF4 @requires_scipy -def test__get_default_engine(): +def test__get_default_engine() -> None: engine_remote = _get_default_engine("http://example.org/test.nc", allow_remote=True) assert engine_remote == "netcdf4" @@ -19,7 +19,7 @@ def test__get_default_engine(): assert engine_default == "netcdf4" -def test_custom_engine(): +def test_custom_engine() -> None: expected = xr.Dataset( dict(a=2 * np.arange(5)), coords=dict(x=("x", np.arange(5), dict(units="s"))) ) diff --git a/xarray/tests/test_backends_common.py b/xarray/tests/test_backends_common.py index 7f91e644e2a..75729a8f046 100644 --- a/xarray/tests/test_backends_common.py +++ b/xarray/tests/test_backends_common.py @@ -18,7 +18,7 @@ def __getitem__(self, key): return "success" -def test_robust_getitem(): +def test_robust_getitem() -> None: array = DummyArray(failures=2) with pytest.raises(DummyFailure): array[...] 
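The hunks above and below all make the same mechanical change, and one detail is worth spelling out: by default mypy does not type-check the body of a function that carries no annotations at all, so adding even a bare `-> None` return type to a test is what opts its body in to checking. That is presumably also why the empty-dict literals in the file-manager tests and the bare `LRUCache(...)` constructions in the LRU-cache tests below now need explicit annotations: once a body is checked, mypy cannot infer key/value types from `{}` or from `LRUCache(maxsize=2)` alone. A minimal standalone sketch of the effect follows; it is illustrative only and sits outside the patch, with hypothetical names:

from typing import Any, Dict

def test_untyped():
    # No annotations anywhere, so mypy (with default settings) skips
    # this body entirely and the untyped empty dict is never flagged.
    cache = {}
    cache["x"] = 1

def test_typed() -> None:
    # The bare `-> None` opts the body in to checking; the empty dict
    # literal then needs an explicit annotation, mirroring the
    # `cache: Dict = {}` changes in the file-manager hunks below.
    cache: Dict[str, int] = {}
    cache["x"] = 1

If the annotation on `cache` in `test_typed` is removed, mypy reports `Need type annotation for "cache"`, while `test_untyped` produces no errors at all; that asymmetry is the reason a return type that carries no runtime information is still worth adding to every test.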
diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 16f059c7bad..6b8c4da01de 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -1,6 +1,7 @@ import gc import pickle import threading +from typing import Dict from unittest import mock import pytest @@ -19,7 +20,7 @@ def file_cache(request): yield LRUCache(maxsize) -def test_file_manager_mock_write(file_cache): +def test_file_manager_mock_write(file_cache) -> None: mock_file = mock.Mock() opener = mock.Mock(spec=open, return_value=mock_file) lock = mock.MagicMock(spec=threading.Lock()) @@ -37,10 +38,10 @@ def test_file_manager_mock_write(file_cache): @pytest.mark.parametrize("expected_warning", [None, RuntimeWarning]) -def test_file_manager_autoclose(expected_warning): +def test_file_manager_autoclose(expected_warning) -> None: mock_file = mock.Mock() opener = mock.Mock(return_value=mock_file) - cache = {} + cache: Dict = {} manager = CachingFileManager(opener, "filename", cache=cache) manager.acquire() @@ -55,10 +56,10 @@ def test_file_manager_autoclose(expected_warning): mock_file.close.assert_called_once_with() -def test_file_manager_autoclose_while_locked(): +def test_file_manager_autoclose_while_locked() -> None: opener = mock.Mock() lock = threading.Lock() - cache = {} + cache: Dict = {} manager = CachingFileManager(opener, "filename", lock=lock, cache=cache) manager.acquire() @@ -74,17 +75,17 @@ def test_file_manager_autoclose_while_locked(): assert cache -def test_file_manager_repr(): +def test_file_manager_repr() -> None: opener = mock.Mock() manager = CachingFileManager(opener, "my-file") assert "my-file" in repr(manager) -def test_file_manager_refcounts(): +def test_file_manager_refcounts() -> None: mock_file = mock.Mock() opener = mock.Mock(spec=open, return_value=mock_file) - cache = {} - ref_counts = {} + cache: Dict = {} + ref_counts: Dict = {} manager = CachingFileManager(opener, "filename", cache=cache, ref_counts=ref_counts) assert ref_counts[manager._key] == 1 @@ -114,10 +115,10 @@ def test_file_manager_refcounts(): assert not cache -def test_file_manager_replace_object(): +def test_file_manager_replace_object() -> None: opener = mock.Mock() - cache = {} - ref_counts = {} + cache: Dict = {} + ref_counts: Dict = {} manager = CachingFileManager(opener, "filename", cache=cache, ref_counts=ref_counts) manager.acquire() @@ -131,7 +132,7 @@ def test_file_manager_replace_object(): manager.close() -def test_file_manager_write_consecutive(tmpdir, file_cache): +def test_file_manager_write_consecutive(tmpdir, file_cache) -> None: path1 = str(tmpdir.join("testing1.txt")) path2 = str(tmpdir.join("testing2.txt")) manager1 = CachingFileManager(open, path1, mode="w", cache=file_cache) @@ -154,7 +155,7 @@ def test_file_manager_write_consecutive(tmpdir, file_cache): assert f.read() == "bar" -def test_file_manager_write_concurrent(tmpdir, file_cache): +def test_file_manager_write_concurrent(tmpdir, file_cache) -> None: path = str(tmpdir.join("testing.txt")) manager = CachingFileManager(open, path, mode="w", cache=file_cache) f1 = manager.acquire() @@ -174,7 +175,7 @@ def test_file_manager_write_concurrent(tmpdir, file_cache): assert f.read() == "foobarbaz" -def test_file_manager_write_pickle(tmpdir, file_cache): +def test_file_manager_write_pickle(tmpdir, file_cache) -> None: path = str(tmpdir.join("testing.txt")) manager = CachingFileManager(open, path, mode="w", cache=file_cache) f = manager.acquire() @@ -190,7 +191,7 @@ def 
test_file_manager_write_pickle(tmpdir, file_cache): assert f.read() == "foobar" -def test_file_manager_read(tmpdir, file_cache): +def test_file_manager_read(tmpdir, file_cache) -> None: path = str(tmpdir.join("testing.txt")) with open(path, "w") as f: @@ -202,7 +203,7 @@ def test_file_manager_read(tmpdir, file_cache): manager.close() -def test_file_manager_acquire_context(tmpdir, file_cache): +def test_file_manager_acquire_context(tmpdir, file_cache) -> None: path = str(tmpdir.join("testing.txt")) with open(path, "w") as f: diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py index f7e48b65d46..0aa5f99f282 100644 --- a/xarray/tests/test_backends_locks.py +++ b/xarray/tests/test_backends_locks.py @@ -3,7 +3,7 @@ from xarray.backends import locks -def test_threaded_lock(): +def test_threaded_lock() -> None: lock1 = locks._get_threaded_lock("foo") assert isinstance(lock1, type(threading.Lock())) lock2 = locks._get_threaded_lock("foo") diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py index 2aaa8c9e631..2b0c7742e5c 100644 --- a/xarray/tests/test_backends_lru_cache.py +++ b/xarray/tests/test_backends_lru_cache.py @@ -1,3 +1,4 @@ +from typing import Any from unittest import mock import pytest @@ -5,8 +6,8 @@ from xarray.backends.lru_cache import LRUCache -def test_simple(): - cache = LRUCache(maxsize=2) +def test_simple() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) cache["x"] = 1 cache["y"] = 2 @@ -22,21 +23,21 @@ def test_simple(): assert list(cache.items()) == [("y", 2), ("z", 3)] -def test_trivial(): - cache = LRUCache(maxsize=0) +def test_trivial() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=0) cache["x"] = 1 assert len(cache) == 0 -def test_invalid(): +def test_invalid() -> None: with pytest.raises(TypeError): - LRUCache(maxsize=None) + LRUCache(maxsize=None) # type: ignore with pytest.raises(ValueError): LRUCache(maxsize=-1) -def test_update_priority(): - cache = LRUCache(maxsize=2) +def test_update_priority() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) cache["x"] = 1 cache["y"] = 2 assert list(cache) == ["x", "y"] @@ -48,15 +49,15 @@ def test_update_priority(): assert list(cache.items()) == [("y", 2), ("x", 3)] -def test_del(): - cache = LRUCache(maxsize=2) +def test_del() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) cache["x"] = 1 cache["y"] = 2 del cache["x"] assert dict(cache) == {"y": 2} -def test_on_evict(): +def test_on_evict() -> None: on_evict = mock.Mock() cache = LRUCache(maxsize=1, on_evict=on_evict) cache["x"] = 1 @@ -64,15 +65,15 @@ def test_on_evict(): on_evict.assert_called_once_with("x", 1) -def test_on_evict_trivial(): +def test_on_evict_trivial() -> None: on_evict = mock.Mock() cache = LRUCache(maxsize=0, on_evict=on_evict) cache["x"] = 1 on_evict.assert_called_once_with("x", 1) -def test_resize(): - cache = LRUCache(maxsize=2) +def test_resize() -> None: + cache: LRUCache[Any, Any] = LRUCache(maxsize=2) assert cache.maxsize == 2 cache["w"] = 0 cache["x"] = 1 diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 526f3fc30c1..af15f997643 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -58,7 +58,7 @@ def da(index): @pytest.mark.parametrize("closed", [None, "left", "right"]) @pytest.mark.parametrize("label", [None, "left", "right"]) @pytest.mark.parametrize("base", [24, 31]) -def test_resample(freqs, closed, label, base): +def 
test_resample(freqs, closed, label, base) -> None: initial_freq, resample_freq = freqs start = "2000-01-01T12:07:01" index_kwargs = dict(start=start, periods=5, freq=initial_freq) @@ -121,7 +121,7 @@ def test_resample(freqs, closed, label, base): ("AS", "left"), ], ) -def test_closed_label_defaults(freq, expected): +def test_closed_label_defaults(freq, expected) -> None: assert CFTimeGrouper(freq=freq).closed == expected assert CFTimeGrouper(freq=freq).label == expected @@ -130,7 +130,7 @@ def test_closed_label_defaults(freq, expected): @pytest.mark.parametrize( "calendar", ["gregorian", "noleap", "all_leap", "360_day", "julian"] ) -def test_calendars(calendar): +def test_calendars(calendar) -> None: # Limited testing for non-standard calendars freq, closed, label, base = "8001T", None, None, 17 loffset = datetime.timedelta(hours=12) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index 278a961166f..ef2eeac0e0b 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -17,14 +17,14 @@ from .test_dataset import ds -def test_coarsen_absent_dims_error(ds): +def test_coarsen_absent_dims_error(ds) -> None: with pytest.raises(ValueError, match=r"not found in Dataset."): ds.coarsen(foo=2) @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) -def test_coarsen_dataset(ds, dask, boundary, side): +def test_coarsen_dataset(ds, dask, boundary, side) -> None: if dask and has_dask: ds = ds.chunk({"x": 4}) @@ -39,7 +39,7 @@ def test_coarsen_dataset(ds, dask, boundary, side): @pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_coords(ds, dask): +def test_coarsen_coords(ds, dask) -> None: if dask and has_dask: ds = ds.chunk({"x": 4}) @@ -64,7 +64,7 @@ def test_coarsen_coords(ds, dask): @requires_cftime -def test_coarsen_coords_cftime(): +def test_coarsen_coords_cftime() -> None: times = xr.cftime_range("2000", periods=6) da = xr.DataArray(range(6), [("time", times)]) actual = da.coarsen(time=3).mean() @@ -79,7 +79,7 @@ def test_coarsen_coords_cftime(): ("mean", ()), ], ) -def test_coarsen_keep_attrs(funcname, argument): +def test_coarsen_keep_attrs(funcname, argument) -> None: global_attrs = {"units": "test", "long_name": "testing"} da_attrs = {"da_attr": "test"} attrs_coords = {"attrs_coords": "test"} @@ -157,7 +157,7 @@ def test_coarsen_keep_attrs(funcname, argument): @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) -def test_coarsen_reduce(ds, window, name): +def test_coarsen_reduce(ds, window, name) -> None: # Use boundary="trim" to accomodate all window sizes used in tests coarsen_obj = ds.coarsen(time=window, boundary="trim") @@ -181,7 +181,7 @@ def test_coarsen_reduce(ds, window, name): ("mean", ()), ], ) -def test_coarsen_da_keep_attrs(funcname, argument): +def test_coarsen_da_keep_attrs(funcname, argument) -> None: attrs_da = {"da_attr": "test"} attrs_coords = {"attrs_coords": "test"} @@ -237,7 +237,7 @@ def test_coarsen_da_keep_attrs(funcname, argument): @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) -def test_coarsen_da_reduce(da, window, name): +def test_coarsen_da_reduce(da, window, name) -> None: if da.isnull().sum() > 1 and window == 1: pytest.skip("These parameters lead to all-NaN slices") @@ 
-251,7 +251,7 @@ def test_coarsen_da_reduce(da, window, name): @pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_construct(dask): +def test_coarsen_construct(dask) -> None: ds = Dataset( { diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 839f2fd1f2e..1c2e5aa505a 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -14,7 +14,7 @@ import dask.array as da -def test_CFMaskCoder_decode(): +def test_CFMaskCoder_decode() -> None: original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}) expected = xr.Variable(("x",), [0, np.nan, 1]) coder = variables.CFMaskCoder() @@ -43,7 +43,7 @@ def test_CFMaskCoder_decode(): CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(), ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()), ) -def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding): +def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding) -> None: original = xr.Variable(("x",), data, encoding=encoding) encoded = encode_cf_variable(original) @@ -55,7 +55,7 @@ def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding): assert_identical(roundtripped, original) -def test_CFMaskCoder_missing_value(): +def test_CFMaskCoder_missing_value() -> None: expected = xr.DataArray( np.array([[26915, 27755, -9999, 27705], [25595, -9999, 28315, -9999]]), dims=["npts", "ntimes"], @@ -74,7 +74,7 @@ def test_CFMaskCoder_missing_value(): @requires_dask -def test_CFMaskCoder_decode_dask(): +def test_CFMaskCoder_decode_dask() -> None: original = xr.Variable(("x",), [0, -1, 1], {"_FillValue": -1}).chunk() expected = xr.Variable(("x",), [0, np.nan, 1]) coder = variables.CFMaskCoder() @@ -87,7 +87,7 @@ def test_CFMaskCoder_decode_dask(): # TODO(shoyer): parameterize when we have more coders -def test_coder_roundtrip(): +def test_coder_roundtrip() -> None: original = xr.Variable(("x",), [0.0, np.nan, 1.0]) coder = variables.CFMaskCoder() roundtripped = coder.decode(coder.encode(original)) @@ -95,7 +95,7 @@ def test_coder_roundtrip(): @pytest.mark.parametrize("dtype", "u1 u2 i1 i2 f2 f4".split()) -def test_scaling_converts_to_float32(dtype): +def test_scaling_converts_to_float32(dtype) -> None: original = xr.Variable( ("x",), np.arange(10, dtype=dtype), encoding=dict(scale_factor=10) ) @@ -109,7 +109,7 @@ def test_scaling_converts_to_float32(dtype): @pytest.mark.parametrize("scale_factor", (10, [10])) @pytest.mark.parametrize("add_offset", (0.1, [0.1])) -def test_scaling_offset_as_list(scale_factor, add_offset): +def test_scaling_offset_as_list(scale_factor, add_offset) -> None: # test for #4631 encoding = dict(scale_factor=scale_factor, add_offset=add_offset) original = xr.Variable(("x",), np.arange(10.0), encoding=encoding) @@ -120,7 +120,7 @@ def test_scaling_offset_as_list(scale_factor, add_offset): @pytest.mark.parametrize("bits", [1, 2, 4, 8]) -def test_decode_unsigned_from_signed(bits): +def test_decode_unsigned_from_signed(bits) -> None: unsigned_dtype = np.dtype(f"u{bits}") signed_dtype = np.dtype(f"i{bits}") original_values = np.array([np.iinfo(unsigned_dtype).max], dtype=unsigned_dtype) @@ -134,7 +134,7 @@ def test_decode_unsigned_from_signed(bits): @pytest.mark.parametrize("bits", [1, 2, 4, 8]) -def test_decode_signed_from_unsigned(bits): +def test_decode_signed_from_unsigned(bits) -> None: unsigned_dtype = np.dtype(f"u{bits}") signed_dtype = np.dtype(f"i{bits}") original_values = np.array([-1], dtype=signed_dtype) diff --git a/xarray/tests/test_coding_strings.py 
b/xarray/tests/test_coding_strings.py index 800e91d9473..e35e31b74ad 100644 --- a/xarray/tests/test_coding_strings.py +++ b/xarray/tests/test_coding_strings.py @@ -13,7 +13,7 @@ import dask.array as da -def test_vlen_dtype(): +def test_vlen_dtype() -> None: dtype = strings.create_vlen_dtype(str) assert dtype.metadata["element_type"] == str assert strings.is_unicode_dtype(dtype) @@ -29,7 +29,7 @@ def test_vlen_dtype(): assert strings.check_vlen_dtype(np.dtype(object)) is None -def test_EncodedStringCoder_decode(): +def test_EncodedStringCoder_decode() -> None: coder = strings.EncodedStringCoder() raw_data = np.array([b"abc", "ß∂µ∆".encode()]) @@ -43,7 +43,7 @@ def test_EncodedStringCoder_decode(): @requires_dask -def test_EncodedStringCoder_decode_dask(): +def test_EncodedStringCoder_decode_dask() -> None: coder = strings.EncodedStringCoder() raw_data = np.array([b"abc", "ß∂µ∆".encode()]) @@ -59,7 +59,7 @@ def test_EncodedStringCoder_decode_dask(): assert_identical(actual_indexed, expected[0]) -def test_EncodedStringCoder_encode(): +def test_EncodedStringCoder_encode() -> None: dtype = strings.create_vlen_dtype(str) raw_data = np.array(["abc", "ß∂µ∆"], dtype=dtype) expected_data = np.array([r.encode("utf-8") for r in raw_data], dtype=object) @@ -86,7 +86,7 @@ def test_EncodedStringCoder_encode(): Variable((), b"a"), ], ) -def test_CharacterArrayCoder_roundtrip(original): +def test_CharacterArrayCoder_roundtrip(original) -> None: coder = strings.CharacterArrayCoder() roundtripped = coder.decode(coder.encode(original)) assert_identical(original, roundtripped) @@ -99,7 +99,7 @@ def test_CharacterArrayCoder_roundtrip(original): np.array([b"a", b"bc"], dtype=strings.create_vlen_dtype(bytes)), ], ) -def test_CharacterArrayCoder_encode(data): +def test_CharacterArrayCoder_encode(data) -> None: coder = strings.CharacterArrayCoder() raw = Variable(("x",), data) actual = coder.encode(raw) @@ -114,7 +114,7 @@ def test_CharacterArrayCoder_encode(data): (Variable(("x",), [b"ab", b"cdef"], encoding={"char_dim_name": "foo"}), "foo"), ], ) -def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name): +def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name) -> None: coder = strings.CharacterArrayCoder() encoded = coder.encode(original) roundtripped = coder.decode(encoded) @@ -123,7 +123,7 @@ def test_CharacterArrayCoder_char_dim_name(original, expected_char_dim_name): assert roundtripped.dims[-1] == original.dims[-1] -def test_StackedBytesArray(): +def test_StackedBytesArray() -> None: array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") actual = strings.StackedBytesArray(array) expected = np.array([b"abc", b"def"], dtype="S") @@ -140,7 +140,7 @@ def test_StackedBytesArray(): actual[B[:, :2]] -def test_StackedBytesArray_scalar(): +def test_StackedBytesArray_scalar() -> None: array = np.array([b"a", b"b", b"c"], dtype="S") actual = strings.StackedBytesArray(array) @@ -158,7 +158,7 @@ def test_StackedBytesArray_scalar(): actual[B[:2]] -def test_StackedBytesArray_vectorized_indexing(): +def test_StackedBytesArray_vectorized_indexing() -> None: array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]], dtype="S") stacked = strings.StackedBytesArray(array) expected = np.array([[b"abc", b"def"], [b"def", b"abc"]]) @@ -169,7 +169,7 @@ def test_StackedBytesArray_vectorized_indexing(): assert_array_equal(actual, expected) -def test_char_to_bytes(): +def test_char_to_bytes() -> None: array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) expected = 
np.array([b"abc", b"def"]) actual = strings.char_to_bytes(array) @@ -180,13 +180,13 @@ def test_char_to_bytes(): assert_array_equal(actual, expected) -def test_char_to_bytes_ndim_zero(): +def test_char_to_bytes_ndim_zero() -> None: expected = np.array(b"a") actual = strings.char_to_bytes(expected) assert_array_equal(actual, expected) -def test_char_to_bytes_size_zero(): +def test_char_to_bytes_size_zero() -> None: array = np.zeros((3, 0), dtype="S1") expected = np.array([b"", b"", b""]) actual = strings.char_to_bytes(array) @@ -194,7 +194,7 @@ def test_char_to_bytes_size_zero(): @requires_dask -def test_char_to_bytes_dask(): +def test_char_to_bytes_dask() -> None: numpy_array = np.array([[b"a", b"b", b"c"], [b"d", b"e", b"f"]]) array = da.from_array(numpy_array, ((2,), (3,))) expected = np.array([b"abc", b"def"]) @@ -208,7 +208,7 @@ def test_char_to_bytes_dask(): strings.char_to_bytes(array.rechunk(1)) -def test_bytes_to_char(): +def test_bytes_to_char() -> None: array = np.array([[b"ab", b"cd"], [b"ef", b"gh"]]) expected = np.array([[[b"a", b"b"], [b"c", b"d"]], [[b"e", b"f"], [b"g", b"h"]]]) actual = strings.bytes_to_char(array) @@ -220,7 +220,7 @@ def test_bytes_to_char(): @requires_dask -def test_bytes_to_char_dask(): +def test_bytes_to_char_dask() -> None: numpy_array = np.array([b"ab", b"cd"]) array = da.from_array(numpy_array, ((1, 1),)) expected = np.array([[b"a", b"b"], [b"c", b"d"]]) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index f0882afe367..aff2cb8cf3a 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -109,7 +109,7 @@ def _all_cftime_date_types(): @requires_cftime @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") @pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) -def test_cf_datetime(num_dates, units, calendar): +def test_cf_datetime(num_dates, units, calendar) -> None: import cftime expected = cftime.num2date( @@ -145,7 +145,7 @@ def test_cf_datetime(num_dates, units, calendar): @requires_cftime -def test_decode_cf_datetime_overflow(): +def test_decode_cf_datetime_overflow() -> None: # checks for # https://github.com/pydata/pandas/issues/14068 # https://github.com/pydata/xarray/issues/975 @@ -165,7 +165,7 @@ def test_decode_cf_datetime_overflow(): assert result == expected[i] -def test_decode_cf_datetime_non_standard_units(): +def test_decode_cf_datetime_non_standard_units() -> None: expected = pd.date_range(periods=100, start="1970-01-01", freq="h") # netCDFs from madis.noaa.gov use this format for their time units # they cannot be parsed by cftime, but pd.Timestamp works @@ -175,7 +175,7 @@ def test_decode_cf_datetime_non_standard_units(): @requires_cftime -def test_decode_cf_datetime_non_iso_strings(): +def test_decode_cf_datetime_non_iso_strings() -> None: # datetime strings that are _almost_ ISO compliant but not quite, # but which cftime.num2date can still parse correctly expected = pd.date_range(periods=100, start="2000-01-01", freq="h") @@ -195,7 +195,7 @@ def test_decode_cf_datetime_non_iso_strings(): @requires_cftime @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_decode_standard_calendar_inside_timestamp_range(calendar): +def test_decode_standard_calendar_inside_timestamp_range(calendar) -> None: import cftime units = "days since 0001-01-01" @@ -215,7 +215,7 @@ def test_decode_standard_calendar_inside_timestamp_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) 
-def test_decode_non_standard_calendar_inside_timestamp_range(calendar): +def test_decode_non_standard_calendar_inside_timestamp_range(calendar) -> None: import cftime units = "days since 0001-01-01" @@ -240,7 +240,7 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) -def test_decode_dates_outside_timestamp_range(calendar): +def test_decode_dates_outside_timestamp_range(calendar) -> None: from datetime import datetime import cftime @@ -267,7 +267,9 @@ def test_decode_dates_outside_timestamp_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_decode_standard_calendar_single_element_inside_timestamp_range(calendar): +def test_decode_standard_calendar_single_element_inside_timestamp_range( + calendar, +) -> None: units = "days since 0001-01-01" for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): @@ -278,7 +280,9 @@ def test_decode_standard_calendar_single_element_inside_timestamp_range(calendar @requires_cftime @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) -def test_decode_non_standard_calendar_single_element_inside_timestamp_range(calendar): +def test_decode_non_standard_calendar_single_element_inside_timestamp_range( + calendar, +) -> None: units = "days since 0001-01-01" for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): @@ -289,7 +293,7 @@ def test_decode_non_standard_calendar_single_element_inside_timestamp_range(cale @requires_cftime @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) -def test_decode_single_element_outside_timestamp_range(calendar): +def test_decode_single_element_outside_timestamp_range(calendar) -> None: import cftime units = "days since 0001-01-01" @@ -309,7 +313,9 @@ def test_decode_single_element_outside_timestamp_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_decode_standard_calendar_multidim_time_inside_timestamp_range(calendar): +def test_decode_standard_calendar_multidim_time_inside_timestamp_range( + calendar, +) -> None: import cftime units = "days since 0001-01-01" @@ -338,7 +344,9 @@ def test_decode_standard_calendar_multidim_time_inside_timestamp_range(calendar) @requires_cftime @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) -def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(calendar): +def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( + calendar, +) -> None: import cftime units = "days since 0001-01-01" @@ -377,7 +385,7 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(calend @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) -def test_decode_multidim_time_outside_timestamp_range(calendar): +def test_decode_multidim_time_outside_timestamp_range(calendar) -> None: from datetime import datetime import cftime @@ -414,7 +422,7 @@ def test_decode_multidim_time_outside_timestamp_range(calendar): ("calendar", "num_time"), [("360_day", 720058.0), ("all_leap", 732059.0), ("366_day", 732059.0)], ) -def test_decode_non_standard_calendar_single_element(calendar, num_time): +def test_decode_non_standard_calendar_single_element(calendar, num_time) -> None: import cftime units = "days since 0001-01-01" @@ -429,7 +437,7 @@ def test_decode_non_standard_calendar_single_element(calendar, num_time): @requires_cftime -def test_decode_360_day_calendar(): +def test_decode_360_day_calendar() 
-> None: import cftime calendar = "360_day" @@ -454,7 +462,7 @@ def test_decode_360_day_calendar(): @requires_cftime -def test_decode_abbreviation(): +def test_decode_abbreviation() -> None: """Test making sure we properly fall back to cftime on abbreviated units.""" import cftime @@ -479,7 +487,7 @@ def test_decode_abbreviation(): ), ], ) -def test_cf_datetime_nan(num_dates, units, expected_list): +def test_cf_datetime_nan(num_dates, units, expected_list) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "All-NaN") actual = coding.times.decode_cf_datetime(num_dates, units) @@ -489,7 +497,7 @@ def test_cf_datetime_nan(num_dates, units, expected_list): @requires_cftime -def test_decoded_cf_datetime_array_2d(): +def test_decoded_cf_datetime_array_2d() -> None: # regression test for GH1229 variable = Variable( ("x", "y"), np.array([[0, 1], [2, 3]]), {"units": "days since 2000-01-01"} @@ -512,7 +520,7 @@ def test_decoded_cf_datetime_array_2d(): @pytest.mark.parametrize(("freq", "units"), FREQUENCIES_TO_ENCODING_UNITS.items()) -def test_infer_datetime_units(freq, units): +def test_infer_datetime_units(freq, units) -> None: dates = pd.date_range("2000", periods=2, freq=freq) expected = f"{units} since 2000-01-01 00:00:00" assert expected == coding.times.infer_datetime_units(dates) @@ -529,7 +537,7 @@ def test_infer_datetime_units(freq, units): (pd.to_datetime(["NaT"]), "days since 1970-01-01 00:00:00"), ], ) -def test_infer_datetime_units_with_NaT(dates, expected): +def test_infer_datetime_units_with_NaT(dates, expected) -> None: assert expected == coding.times.infer_datetime_units(dates) @@ -551,7 +559,7 @@ def test_infer_datetime_units_with_NaT(dates, expected): "calendar", _NON_STANDARD_CALENDARS + ["gregorian", "proleptic_gregorian"] ) @pytest.mark.parametrize(("date_args", "expected"), _CFTIME_DATETIME_UNITS_TESTS) -def test_infer_cftime_datetime_units(calendar, date_args, expected): +def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: date_type = _all_cftime_date_types()[calendar] dates = [date_type(*args) for args in date_args] assert expected == coding.times.infer_datetime_units(dates) @@ -572,7 +580,7 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected): (["NaT", "NaT"], "days", [np.nan, np.nan]), ], ) -def test_cf_timedelta(timedeltas, units, numbers): +def test_cf_timedelta(timedeltas, units, numbers) -> None: if timedeltas == "NaT": timedeltas = np.timedelta64("NaT", "ns") else: @@ -595,17 +603,16 @@ def test_cf_timedelta(timedeltas, units, numbers): assert_array_equal(expected, actual) -def test_cf_timedelta_2d(): - timedeltas = ["1D", "2D", "3D"] +def test_cf_timedelta_2d() -> None: units = "days" numbers = np.atleast_2d([1, 2, 3]) - timedeltas = np.atleast_2d(to_timedelta_unboxed(timedeltas)) + timedeltas = np.atleast_2d(to_timedelta_unboxed(["1D", "2D", "3D"])) expected = timedeltas actual = coding.times.decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) - assert expected.dtype == actual.dtype + assert expected.dtype == actual.dtype # type: ignore @pytest.mark.parametrize( @@ -617,7 +624,7 @@ def test_cf_timedelta_2d(): (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), ], ) -def test_infer_timedelta_units(deltas, expected): +def test_infer_timedelta_units(deltas, expected) -> None: assert expected == coding.times.infer_timedelta_units(deltas) @@ -631,7 +638,7 @@ def test_infer_timedelta_units(deltas, expected): ((1000, 2, 3, 4, 5, 6), "1000-02-03 04:05:06.000000"), ], ) -def 
test_format_cftime_datetime(date_args, expected): +def test_format_cftime_datetime(date_args, expected) -> None: date_types = _all_cftime_date_types() for date_type in date_types.values(): result = coding.times.format_cftime_datetime(date_type(*date_args)) @@ -639,9 +646,10 @@ def test_format_cftime_datetime(date_args, expected): @pytest.mark.parametrize("calendar", _ALL_CALENDARS) -def test_decode_cf(calendar): +def test_decode_cf(calendar) -> None: days = [1.0, 2.0, 3.0] - da = DataArray(days, coords=[days], dims=["time"], name="test") + # TODO: GH5690 — do we want to allow this type for `coords`? + da = DataArray(days, coords=[days], dims=["time"], name="test") # type: ignore ds = da.to_dataset() for v in ["test", "time"]: @@ -660,7 +668,7 @@ def test_decode_cf(calendar): assert ds.test.dtype == np.dtype("M8[ns]") -def test_decode_cf_time_bounds(): +def test_decode_cf_time_bounds() -> None: da = DataArray( np.arange(6, dtype="int64").reshape((3, 2)), @@ -703,7 +711,7 @@ def test_decode_cf_time_bounds(): @requires_cftime -def test_encode_time_bounds(): +def test_encode_time_bounds() -> None: time = pd.date_range("2000-01-16", periods=1) time_bounds = pd.date_range("2000-01-01", periods=2, freq="MS") @@ -779,41 +787,41 @@ def times_3d(times): @requires_cftime -def test_contains_cftime_datetimes_1d(data): +def test_contains_cftime_datetimes_1d(data) -> None: assert contains_cftime_datetimes(data.time) @requires_cftime @requires_dask -def test_contains_cftime_datetimes_dask_1d(data): +def test_contains_cftime_datetimes_dask_1d(data) -> None: assert contains_cftime_datetimes(data.time.chunk()) @requires_cftime -def test_contains_cftime_datetimes_3d(times_3d): +def test_contains_cftime_datetimes_3d(times_3d) -> None: assert contains_cftime_datetimes(times_3d) @requires_cftime @requires_dask -def test_contains_cftime_datetimes_dask_3d(times_3d): +def test_contains_cftime_datetimes_dask_3d(times_3d) -> None: assert contains_cftime_datetimes(times_3d.chunk()) @pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) -def test_contains_cftime_datetimes_non_cftimes(non_cftime_data): +def test_contains_cftime_datetimes_non_cftimes(non_cftime_data) -> None: assert not contains_cftime_datetimes(non_cftime_data) @requires_dask @pytest.mark.parametrize("non_cftime_data", [DataArray([]), DataArray([1, 2])]) -def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data): +def test_contains_cftime_datetimes_non_cftimes_dask(non_cftime_data) -> None: assert not contains_cftime_datetimes(non_cftime_data.chunk()) @requires_cftime @pytest.mark.parametrize("shape", [(24,), (8, 3), (2, 4, 3)]) -def test_encode_cf_datetime_overflow(shape): +def test_encode_cf_datetime_overflow(shape) -> None: # Test for fix to GH 2272 dates = pd.date_range("2100", periods=24).values.reshape(shape) units = "days since 1800-01-01" @@ -824,7 +832,7 @@ def test_encode_cf_datetime_overflow(shape): np.testing.assert_array_equal(dates, roundtrip) -def test_encode_expected_failures(): +def test_encode_expected_failures() -> None: dates = pd.date_range("2000", periods=3) with pytest.raises(ValueError, match="invalid time units"): @@ -833,7 +841,7 @@ def test_encode_expected_failures(): encode_cf_datetime(dates, units="days since NO_YEAR") -def test_encode_cf_datetime_pandas_min(): +def test_encode_cf_datetime_pandas_min() -> None: # GH 2623 dates = pd.date_range("2000", periods=3) num, units, calendar = encode_cf_datetime(dates) @@ -846,7 +854,7 @@ def test_encode_cf_datetime_pandas_min(): @requires_cftime 
-def test_time_units_with_timezone_roundtrip(calendar): +def test_time_units_with_timezone_roundtrip(calendar) -> None: # Regression test for GH 2649 expected_units = "days since 2000-01-01T00:00:00-05:00" expected_num_dates = np.array([1, 2, 3]) @@ -874,7 +882,7 @@ def test_time_units_with_timezone_roundtrip(calendar): @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_use_cftime_default_standard_calendar_in_range(calendar): +def test_use_cftime_default_standard_calendar_in_range(calendar) -> None: numerical_dates = [0, 1] units = "days since 2000-01-01" expected = pd.date_range("2000", periods=2) @@ -888,7 +896,9 @@ def test_use_cftime_default_standard_calendar_in_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2500]) -def test_use_cftime_default_standard_calendar_out_of_range(calendar, units_year): +def test_use_cftime_default_standard_calendar_out_of_range( + calendar, units_year +) -> None: from cftime import num2date numerical_dates = [0, 1] @@ -905,7 +915,7 @@ def test_use_cftime_default_standard_calendar_out_of_range(calendar, units_year) @requires_cftime @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2000, 2500]) -def test_use_cftime_default_non_standard_calendar(calendar, units_year): +def test_use_cftime_default_non_standard_calendar(calendar, units_year) -> None: from cftime import num2date numerical_dates = [0, 1] @@ -923,7 +933,7 @@ def test_use_cftime_default_non_standard_calendar(calendar, units_year): @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2000, 2500]) -def test_use_cftime_true(calendar, units_year): +def test_use_cftime_true(calendar, units_year) -> None: from cftime import num2date numerical_dates = [0, 1] @@ -939,7 +949,7 @@ def test_use_cftime_true(calendar, units_year): @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_use_cftime_false_standard_calendar_in_range(calendar): +def test_use_cftime_false_standard_calendar_in_range(calendar) -> None: numerical_dates = [0, 1] units = "days since 2000-01-01" expected = pd.date_range("2000", periods=2) @@ -952,7 +962,7 @@ def test_use_cftime_false_standard_calendar_in_range(calendar): @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2500]) -def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year): +def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year) -> None: numerical_dates = [0, 1] units = f"days since {units_year}-01-01" with pytest.raises(OutOfBoundsDatetime): @@ -961,7 +971,7 @@ def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year): @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2000, 2500]) -def test_use_cftime_false_non_standard_calendar(calendar, units_year): +def test_use_cftime_false_non_standard_calendar(calendar, units_year) -> None: numerical_dates = [0, 1] units = f"days since {units_year}-01-01" with pytest.raises(OutOfBoundsDatetime): @@ -970,7 +980,7 @@ def test_use_cftime_false_non_standard_calendar(calendar, units_year): @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) -def test_decode_ambiguous_time_warns(calendar): +def test_decode_ambiguous_time_warns(calendar) -> None: # GH 4422, 4506 from cftime import num2date @@ -1003,7 +1013,9 @@ def 
test_decode_ambiguous_time_warns(calendar): @pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) @pytest.mark.parametrize("date_range", [pd.date_range, cftime_range]) -def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq, date_range): +def test_encode_cf_datetime_defaults_to_correct_dtype( + encoding_units, freq, date_range +) -> None: if not has_cftime_1_4_1 and date_range == cftime_range: pytest.skip("Test requires cftime 1.4.1.") if (freq == "N" or encoding_units == "nanoseconds") and date_range == cftime_range: @@ -1021,7 +1033,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq, date @pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) -def test_encode_decode_roundtrip_datetime64(freq): +def test_encode_decode_roundtrip_datetime64(freq) -> None: # See GH 4045. Prior to GH 4684 this test would fail for frequencies of # "S", "L", "U", and "N". initial_time = pd.date_range("1678-01-01", periods=1) @@ -1034,7 +1046,7 @@ def test_encode_decode_roundtrip_datetime64(freq): @requires_cftime_1_4_1 @pytest.mark.parametrize("freq", ["U", "L", "S", "T", "H", "D"]) -def test_encode_decode_roundtrip_cftime(freq): +def test_encode_decode_roundtrip_cftime(freq) -> None: initial_time = cftime_range("0001", periods=1) times = initial_time.append( cftime_range("0001", periods=2, freq=freq) + timedelta(days=291000 * 365) @@ -1046,7 +1058,7 @@ def test_encode_decode_roundtrip_cftime(freq): @requires_cftime -def test__encode_datetime_with_cftime(): +def test__encode_datetime_with_cftime() -> None: # See GH 4870. cftime versions > 1.4.0 required us to adapt the # way _encode_datetime_with_cftime was written. import cftime @@ -1061,7 +1073,7 @@ def test__encode_datetime_with_cftime(): @pytest.mark.parametrize("calendar", ["gregorian", "Gregorian", "GREGORIAN"]) -def test_decode_encode_roundtrip_with_non_lowercase_letters(calendar): +def test_decode_encode_roundtrip_with_non_lowercase_letters(calendar) -> None: # See GH 5093. 
times = [0, 1] units = "days since 2000-01-01" diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 2439ea30b4b..22a3efce999 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -38,7 +38,7 @@ def assert_identical(a, b): assert_array_equal(a, b) -def test_signature_properties(): +def test_signature_properties() -> None: sig = _UFuncSignature([["x"], ["x", "y"]], [["z"]]) assert sig.input_core_dims == (("x",), ("x", "y")) assert sig.output_core_dims == (("z",),) @@ -55,7 +55,7 @@ def test_signature_properties(): assert _UFuncSignature([["x"]]) != _UFuncSignature([["y"]]) -def test_result_name(): +def test_result_name() -> None: class Named: def __init__(self, name=None): self.name = name @@ -67,20 +67,20 @@ def __init__(self, name=None): assert result_name([Named("foo"), Named()]) is None -def test_ordered_set_union(): +def test_ordered_set_union() -> None: assert list(ordered_set_union([[1, 2]])) == [1, 2] assert list(ordered_set_union([[1, 2], [2, 1]])) == [1, 2] assert list(ordered_set_union([[0], [1, 2], [1, 3]])) == [0, 1, 2, 3] -def test_ordered_set_intersection(): +def test_ordered_set_intersection() -> None: assert list(ordered_set_intersection([[1, 2]])) == [1, 2] assert list(ordered_set_intersection([[1, 2], [2, 1]])) == [1, 2] assert list(ordered_set_intersection([[1, 2], [1, 3]])) == [1] assert list(ordered_set_intersection([[1, 2], [2]])) == [2] -def test_join_dict_keys(): +def test_join_dict_keys() -> None: dicts = [dict.fromkeys(keys) for keys in [["x", "y"], ["y", "z"]]] assert list(join_dict_keys(dicts, "left")) == ["x", "y"] assert list(join_dict_keys(dicts, "right")) == ["y", "z"] @@ -92,7 +92,7 @@ def test_join_dict_keys(): join_dict_keys(dicts, "foobar") -def test_collect_dict_values(): +def test_collect_dict_values() -> None: dicts = [{"x": 1, "y": 2, "z": 3}, {"z": 4}, 5] expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]] collected = collect_dict_values(dicts, ["x", "y", "z"], fill_value=0) @@ -103,7 +103,7 @@ def identity(x): return x -def test_apply_identity(): +def test_apply_identity() -> None: array = np.arange(10) variable = xr.Variable("x", array) data_array = xr.DataArray(variable, [("x", -array)]) @@ -123,7 +123,7 @@ def add(a, b): return apply_ufunc(operator.add, a, b) -def test_apply_two_inputs(): +def test_apply_two_inputs() -> None: array = np.array([1, 2, 3]) variable = xr.Variable("x", array) data_array = xr.DataArray(variable, [("x", -array)]) @@ -170,7 +170,7 @@ def test_apply_two_inputs(): assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) -def test_apply_1d_and_0d(): +def test_apply_1d_and_0d() -> None: array = np.array([1, 2, 3]) variable = xr.Variable("x", array) data_array = xr.DataArray(variable, [("x", -array)]) @@ -217,7 +217,7 @@ def test_apply_1d_and_0d(): assert_identical(dataset, add(zero_dataset, dataset.groupby("x"))) -def test_apply_two_outputs(): +def test_apply_two_outputs() -> None: array = np.arange(5) variable = xr.Variable("x", array) data_array = xr.DataArray(variable, [("x", -array)]) @@ -255,7 +255,7 @@ def func(x): @requires_dask -def test_apply_dask_parallelized_two_outputs(): +def test_apply_dask_parallelized_two_outputs() -> None: data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) def twice(obj): @@ -269,7 +269,7 @@ def func(x): assert_identical(data_array, out1) -def test_apply_input_core_dimension(): +def test_apply_input_core_dimension() -> None: def first_element(obj, dim): def func(x): return x[..., 0] @@ -329,7 +329,7 @@ def 
multiply(*args): assert_identical(expected, actual) -def test_apply_output_core_dimension(): +def test_apply_output_core_dimension() -> None: def stack_negative(obj): def func(x): return np.stack([x, -x], axis=-1) @@ -391,7 +391,7 @@ def func(x): assert_identical(stacked_dataset, out1) -def test_apply_exclude(): +def test_apply_exclude() -> None: def concatenate(objects, dim="x"): def func(*x): return np.concatenate(x, axis=-1) @@ -432,7 +432,7 @@ def func(*x): apply_ufunc(identity, variables[0], exclude_dims={"x"}) -def test_apply_groupby_add(): +def test_apply_groupby_add() -> None: array = np.arange(5) variable = xr.Variable("x", array) coords = {"x": -array, "y": ("x", [0, 0, 1, 1, 2])} @@ -469,7 +469,7 @@ def test_apply_groupby_add(): add(data_array.groupby("y"), data_array.groupby("x")) -def test_unified_dim_sizes(): +def test_unified_dim_sizes() -> None: assert unified_dim_sizes([xr.Variable((), 0)]) == {} assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1} assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("y", [1, 2])]) == { @@ -493,7 +493,7 @@ def test_unified_dim_sizes(): unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1, 2])]) -def test_broadcast_compat_data_1d(): +def test_broadcast_compat_data_1d() -> None: data = np.arange(5) var = xr.Variable("x", data) @@ -509,7 +509,7 @@ def test_broadcast_compat_data_1d(): broadcast_compat_data(var, (), ()) -def test_broadcast_compat_data_2d(): +def test_broadcast_compat_data_2d() -> None: data = np.arange(12).reshape(3, 4) var = xr.Variable(["x", "y"], data) @@ -529,7 +529,7 @@ def test_broadcast_compat_data_2d(): ) -def test_keep_attrs(): +def test_keep_attrs() -> None: def add(a, b, keep_attrs): if keep_attrs: return apply_ufunc(operator.add, a, b, keep_attrs=keep_attrs) @@ -552,16 +552,16 @@ def add(a, b, keep_attrs): actual = add(a.variable, b.variable, keep_attrs=True) assert_identical(actual.attrs, a.attrs) - a = xr.Dataset({"x": [0, 1]}) - a.attrs["attr"] = "ds" - a.x.attrs["attr"] = "da" - b = xr.Dataset({"x": [0, 1]}) + ds_a = xr.Dataset({"x": [0, 1]}) + ds_a.attrs["attr"] = "ds" + ds_a.x.attrs["attr"] = "da" + ds_b = xr.Dataset({"x": [0, 1]}) - actual = add(a, b, keep_attrs=False) + actual = add(ds_a, ds_b, keep_attrs=False) assert not actual.attrs - actual = add(a, b, keep_attrs=True) - assert_identical(actual.attrs, a.attrs) - assert_identical(actual.x.attrs, a.x.attrs) + actual = add(ds_a, ds_b, keep_attrs=True) + assert_identical(actual.attrs, ds_a.attrs) + assert_identical(actual.x.attrs, ds_a.x.attrs) @pytest.mark.parametrize( @@ -618,7 +618,7 @@ def add(a, b, keep_attrs): ), ), ) -def test_keep_attrs_strategies_variable(strategy, attrs, expected, error): +def test_keep_attrs_strategies_variable(strategy, attrs, expected, error) -> None: a = xr.Variable("x", [0, 1], attrs=attrs[0]) b = xr.Variable("x", [0, 1], attrs=attrs[1]) c = xr.Variable("x", [0, 1], attrs=attrs[2]) @@ -687,7 +687,7 @@ def test_keep_attrs_strategies_variable(strategy, attrs, expected, error): ), ), ) -def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error): +def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error) -> None: a = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[0]) b = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[1]) c = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[2]) @@ -852,7 +852,7 @@ def test_keep_attrs_strategies_dataarray_variables( ), ), ) -def test_keep_attrs_strategies_dataset(strategy, attrs, expected, error): +def 
test_keep_attrs_strategies_dataset(strategy, attrs, expected, error) -> None: a = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[0]) b = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[1]) c = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[2]) @@ -959,7 +959,7 @@ def test_keep_attrs_strategies_dataset_variables( assert_identical(actual, expected) -def test_dataset_join(): +def test_dataset_join() -> None: ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]}) @@ -1007,7 +1007,7 @@ def add(a, b, join, dataset_join): @requires_dask -def test_apply_dask(): +def test_apply_dask() -> None: import dask.array as da array = da.ones((2,), chunks=2) @@ -1049,7 +1049,7 @@ def dask_safe_identity(x): @requires_dask -def test_apply_dask_parallelized_one_arg(): +def test_apply_dask_parallelized_one_arg() -> None: import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) @@ -1069,7 +1069,7 @@ def parallel_identity(x): @requires_dask -def test_apply_dask_parallelized_two_args(): +def test_apply_dask_parallelized_two_args() -> None: import dask.array as da array = da.ones((2, 2), chunks=(1, 1), dtype=np.int64) @@ -1097,7 +1097,7 @@ def check(x, y): @requires_dask -def test_apply_dask_parallelized_errors(): +def test_apply_dask_parallelized_errors() -> None: import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) @@ -1123,7 +1123,7 @@ def test_apply_dask_parallelized_errors(): # https://github.com/dask/dask/issues/3245 @requires_dask @pytest.mark.filterwarnings("ignore:Mean of empty slice") -def test_apply_dask_multiple_inputs(): +def test_apply_dask_multiple_inputs() -> None: import dask.array as da def covariance(x, y): @@ -1166,7 +1166,7 @@ def covariance(x, y): @requires_dask -def test_apply_dask_new_output_dimension(): +def test_apply_dask_new_output_dimension() -> None: import dask.array as da array = da.ones((2, 2), chunks=(1, 1)) @@ -1195,7 +1195,7 @@ def func(x): @requires_dask -def test_apply_dask_new_output_sizes(): +def test_apply_dask_new_output_sizes() -> None: ds = xr.Dataset({"foo": (["lon", "lat"], np.arange(10 * 10).reshape((10, 10)))}) ds["bar"] = ds["foo"] newdims = {"lon_new": 3, "lat_new": 6} @@ -1224,7 +1224,7 @@ def pandas_median(x): return pd.Series(x).median() -def test_vectorize(): +def test_vectorize() -> None: data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) expected = xr.DataArray([1, 2], dims=["x"]) actual = apply_ufunc( @@ -1234,7 +1234,7 @@ def test_vectorize(): @requires_dask -def test_vectorize_dask(): +def test_vectorize_dask() -> None: # run vectorization in dask.array.gufunc by using `dask='parallelized'` data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) expected = xr.DataArray([1, 2], dims=["x"]) @@ -1250,7 +1250,7 @@ def test_vectorize_dask(): @requires_dask -def test_vectorize_dask_dtype(): +def test_vectorize_dask_dtype() -> None: # ensure output_dtypes is preserved with vectorize=True # GH4015 @@ -1290,7 +1290,7 @@ def test_vectorize_dask_dtype(): xr.DataArray([[0 + 0j, 1 + 2j, 2 + 1j]], dims=("x", "y")), ], ) -def test_vectorize_dask_dtype_without_output_dtypes(data_array): +def test_vectorize_dask_dtype_without_output_dtypes(data_array) -> None: # ensure output_dtypes is preserved with vectorize=True # GH4015 @@ -1311,7 +1311,7 @@ def test_vectorize_dask_dtype_without_output_dtypes(data_array): reason="dask/dask#7669: can no longer pass output_dtypes and meta", ) @requires_dask -def test_vectorize_dask_dtype_meta(): +def test_vectorize_dask_dtype_meta() -> None: # meta 
dtype takes precedence data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) expected = xr.DataArray([1, 2], dims=["x"]) @@ -1335,7 +1335,7 @@ def pandas_median_add(x, y): return pd.Series(x).median() + pd.Series(y).median() -def test_vectorize_exclude_dims(): +def test_vectorize_exclude_dims() -> None: # GH 3890 data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) @@ -1353,7 +1353,7 @@ def test_vectorize_exclude_dims(): @requires_dask -def test_vectorize_exclude_dims_dask(): +def test_vectorize_exclude_dims_dask() -> None: # GH 3890 data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) @@ -1372,7 +1372,7 @@ def test_vectorize_exclude_dims_dask(): assert_identical(expected, actual) -def test_corr_only_dataarray(): +def test_corr_only_dataarray() -> None: with pytest.raises(TypeError, match="Only xr.DataArray is supported"): xr.corr(xr.Dataset(), xr.Dataset()) @@ -1420,7 +1420,7 @@ def arrays_w_tuples(): ], ) @pytest.mark.parametrize("dim", [None, "x", "time"]) -def test_lazy_corrcov(da_a, da_b, dim, ddof): +def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None: # GH 5284 from dask import is_dask_collection @@ -1438,7 +1438,7 @@ def test_lazy_corrcov(da_a, da_b, dim, ddof): [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]], ) @pytest.mark.parametrize("dim", [None, "time"]) -def test_cov(da_a, da_b, dim, ddof): +def test_cov(da_a, da_b, dim, ddof) -> None: if dim is not None: def np_cov_ind(ts1, ts2, a, x): @@ -1490,7 +1490,7 @@ def np_cov(ts1, ts2): [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]], ) @pytest.mark.parametrize("dim", [None, "time"]) -def test_corr(da_a, da_b, dim): +def test_corr(da_a, da_b, dim) -> None: if dim is not None: def np_corr_ind(ts1, ts2, a, x): @@ -1538,7 +1538,7 @@ def np_corr(ts1, ts2): arrays_w_tuples()[1], ) @pytest.mark.parametrize("dim", [None, "time", "x"]) -def test_covcorr_consistency(da_a, da_b, dim): +def test_covcorr_consistency(da_a, da_b, dim) -> None: # Testing that xr.corr and xr.cov are consistent with each other # 1. Broadcast the two arrays da_a, da_b = broadcast(da_a, da_b) @@ -1559,7 +1559,7 @@ def test_covcorr_consistency(da_a, da_b, dim): arrays_w_tuples()[0], ) @pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]]) -def test_autocov(da_a, dim): +def test_autocov(da_a, dim) -> None: # Testing that the autocovariance*(N-1) is ~=~ to the variance matrix # 1. 
Ignore the nans valid_values = da_a.notnull() @@ -1571,7 +1571,7 @@ def test_autocov(da_a, dim): @requires_dask -def test_vectorize_dask_new_output_dims(): +def test_vectorize_dask_new_output_dims() -> None: # regression test for GH3574 # run vectorization in dask.array.gufunc by using `dask='parallelized'` data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) @@ -1614,7 +1614,7 @@ def test_vectorize_dask_new_output_dims(): ) -def test_output_wrong_number(): +def test_output_wrong_number() -> None: variable = xr.Variable("x", np.arange(10)) def identity(x): @@ -1630,7 +1630,7 @@ def tuple3x(x): apply_ufunc(tuple3x, variable, output_core_dims=[(), ()]) -def test_output_wrong_dims(): +def test_output_wrong_dims() -> None: variable = xr.Variable("x", np.arange(10)) def add_dim(x): @@ -1649,7 +1649,7 @@ def remove_dim(x): apply_ufunc(remove_dim, variable) -def test_output_wrong_dim_size(): +def test_output_wrong_dim_size() -> None: array = np.arange(10) variable = xr.Variable("x", array) data_array = xr.DataArray(variable, [("x", -array)]) @@ -1710,7 +1710,7 @@ def apply_truncate_x_x_valid(obj): @pytest.mark.parametrize("use_dask", [True, False]) -def test_dot(use_dask): +def test_dot(use_dask) -> None: if use_dask: if not has_dask: pytest.skip("test for dask.") @@ -1840,7 +1840,7 @@ def test_dot(use_dask): @pytest.mark.parametrize("use_dask", [True, False]) -def test_dot_align_coords(use_dask): +def test_dot_align_coords(use_dask) -> None: # GH 3694 if use_dask: @@ -1893,7 +1893,7 @@ def test_dot_align_coords(use_dask): xr.testing.assert_allclose(expected, actual) -def test_where(): +def test_where() -> None: cond = xr.DataArray([True, False], dims="x") actual = xr.where(cond, 1, 0) expected = xr.DataArray([1, 0], dims="x") @@ -1902,7 +1902,7 @@ def test_where(): @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("use_datetime", [True, False]) -def test_polyval(use_dask, use_datetime): +def test_polyval(use_dask, use_datetime) -> None: if use_dask and not has_dask: pytest.skip("requires dask") diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index ceea167719f..b364b405423 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -23,7 +23,7 @@ class TestBoolTypeArray: - def test_booltype_array(self): + def test_booltype_array(self) -> None: x = np.array([1, 0, 1, 1, 0], dtype="i1") bx = conventions.BoolTypeArray(x) assert bx.dtype == bool @@ -31,7 +31,7 @@ def test_booltype_array(self): class TestNativeEndiannessArray: - def test(self): + def test(self) -> None: x = np.arange(5, dtype=">i8") expected = np.arange(5, dtype="int64") a = conventions.NativeEndiannessArray(x) @@ -40,7 +40,7 @@ def test(self): assert_array_equal(a, expected) -def test_decode_cf_with_conflicting_fill_missing_value(): +def test_decode_cf_with_conflicting_fill_missing_value() -> None: expected = Variable(["t"], [np.nan, np.nan, 2], {"units": "foobar"}) var = Variable( ["t"], np.arange(3), {"units": "foobar", "missing_value": 0, "_FillValue": 1} @@ -75,7 +75,7 @@ def test_decode_cf_with_conflicting_fill_missing_value(): @requires_cftime class TestEncodeCFVariable: - def test_incompatible_attributes(self): + def test_incompatible_attributes(self) -> None: invalid_vars = [ Variable( ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} @@ -88,13 +88,13 @@ def test_incompatible_attributes(self): with pytest.raises(ValueError): conventions.encode_cf_variable(var) - def test_missing_fillvalue(self): + def 
test_missing_fillvalue(self) -> None: v = Variable(["x"], np.array([np.nan, 1, 2, 3])) v.encoding = {"dtype": "int16"} with pytest.warns(Warning, match="floating point data as an integer"): conventions.encode_cf_variable(v) - def test_multidimensional_coordinates(self): + def test_multidimensional_coordinates(self) -> None: # regression test for GH1763 # Set up test case with coordinates that have overlapping (but not # identical) dimensions. @@ -128,7 +128,7 @@ def test_multidimensional_coordinates(self): # Should not have any global coordinates. assert "coordinates" not in attrs - def test_do_not_overwrite_user_coordinates(self): + def test_do_not_overwrite_user_coordinates(self) -> None: orig = Dataset( coords={"x": [0, 1, 2], "y": ("x", [5, 6, 7]), "z": ("x", [8, 9, 10])}, data_vars={"a": ("x", [1, 2, 3]), "b": ("x", [3, 5, 6])}, @@ -142,7 +142,7 @@ def test_do_not_overwrite_user_coordinates(self): with pytest.raises(ValueError, match=r"'coordinates' found in both attrs"): conventions.encode_dataset_coordinates(orig) - def test_emit_coordinates_attribute_in_attrs(self): + def test_emit_coordinates_attribute_in_attrs(self) -> None: orig = Dataset( {"a": 1, "b": 1}, coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, @@ -159,7 +159,7 @@ def test_emit_coordinates_attribute_in_attrs(self): assert enc["b"].attrs.get("coordinates") == "t" assert "coordinates" not in enc["b"].encoding - def test_emit_coordinates_attribute_in_encoding(self): + def test_emit_coordinates_attribute_in_encoding(self) -> None: orig = Dataset( {"a": 1, "b": 1}, coords={"t": np.array("2004-11-01T00:00:00", dtype=np.datetime64)}, @@ -177,7 +177,7 @@ def test_emit_coordinates_attribute_in_encoding(self): assert "coordinates" not in enc["b"].encoding @requires_dask - def test_string_object_warning(self): + def test_string_object_warning(self) -> None: original = Variable(("x",), np.array(["foo", "bar"], dtype=object)).chunk() with pytest.warns(SerializationWarning, match="dask array with dtype=object"): encoded = conventions.encode_cf_variable(original) @@ -186,7 +186,7 @@ def test_string_object_warning(self): @requires_cftime class TestDecodeCF: - def test_dataset(self): + def test_dataset(self) -> None: original = Dataset( { "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), @@ -204,13 +204,13 @@ def test_dataset(self): actual = conventions.decode_cf(original) assert_identical(expected, actual) - def test_invalid_coordinates(self): + def test_invalid_coordinates(self) -> None: # regression test for GH308 original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) assert_identical(original, actual) - def test_decode_coordinates(self): + def test_decode_coordinates(self) -> None: # regression test for GH610 original = Dataset( {"foo": ("t", [1, 2], {"coordinates": "x"}), "x": ("t", [4, 5])} @@ -218,13 +218,13 @@ def test_decode_coordinates(self): actual = conventions.decode_cf(original) assert actual.foo.encoding["coordinates"] == "x" - def test_0d_int32_encoding(self): + def test_0d_int32_encoding(self) -> None: original = Variable((), np.int32(0), encoding={"dtype": "int64"}) expected = Variable((), np.int64(0)) actual = conventions.maybe_encode_nonstring_dtype(original) assert_identical(expected, actual) - def test_decode_cf_with_multiple_missing_values(self): + def test_decode_cf_with_multiple_missing_values(self) -> None: original = Variable(["t"], [0, 1, 2], {"missing_value": np.array([0, 1])}) expected = Variable(["t"], [np.nan, np.nan, 
2], {}) with warnings.catch_warnings(record=True) as w: @@ -232,7 +232,7 @@ def test_decode_cf_with_multiple_missing_values(self): assert_identical(expected, actual) assert "has multiple fill" in str(w[0].message) - def test_decode_cf_with_drop_variables(self): + def test_decode_cf_with_drop_variables(self) -> None: original = Dataset( { "t": ("t", [0, 1, 2], {"units": "days since 2000-01-01"}), @@ -262,13 +262,13 @@ def test_decode_cf_with_drop_variables(self): assert_identical(expected, actual2) @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") - def test_invalid_time_units_raises_eagerly(self): + def test_invalid_time_units_raises_eagerly(self) -> None: ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})}) with pytest.raises(ValueError, match=r"unable to decode time"): decode_cf(ds) @requires_cftime - def test_dataset_repr_with_netcdf4_datetimes(self): + def test_dataset_repr_with_netcdf4_datetimes(self) -> None: # regression test for #347 attrs = {"units": "days since 0001-01-01", "calendar": "noleap"} with warnings.catch_warnings(): @@ -281,7 +281,7 @@ def test_dataset_repr_with_netcdf4_datetimes(self): assert "(time) datetime64[ns]" in repr(ds) @requires_cftime - def test_decode_cf_datetime_transition_to_invalid(self): + def test_decode_cf_datetime_transition_to_invalid(self) -> None: # manually create dataset with not-decoded date from datetime import datetime @@ -297,7 +297,7 @@ def test_decode_cf_datetime_transition_to_invalid(self): assert_array_equal(ds_decoded.time.values, expected) @requires_dask - def test_decode_cf_with_dask(self): + def test_decode_cf_with_dask(self) -> None: import dask.array as da original = Dataset( @@ -319,7 +319,7 @@ def test_decode_cf_with_dask(self): assert_identical(decoded, conventions.decode_cf(original).compute()) @requires_dask - def test_decode_dask_times(self): + def test_decode_dask_times(self) -> None: original = Dataset.from_dict( { "coords": {}, @@ -338,7 +338,7 @@ def test_decode_dask_times(self): conventions.decode_cf(original).chunk(), ) - def test_decode_cf_time_kwargs(self): + def test_decode_cf_time_kwargs(self) -> None: ds = Dataset.from_dict( { "coords": { @@ -401,18 +401,18 @@ def roundtrip( yield open_dataset(store, **open_kwargs) @pytest.mark.skip("cannot roundtrip coordinates yet for CFEncodedInMemoryStore") - def test_roundtrip_coordinates(self): + def test_roundtrip_coordinates(self) -> None: pass - def test_invalid_dataarray_names_raise(self): + def test_invalid_dataarray_names_raise(self) -> None: # only relevant for on-disk file formats pass - def test_encoding_kwarg(self): + def test_encoding_kwarg(self) -> None: # we haven't bothered to raise errors yet for unexpected encodings in # this test dummy pass - def test_encoding_kwarg_fixed_width_string(self): + def test_encoding_kwarg_fixed_width_string(self) -> None: # CFEncodedInMemoryStore doesn't support explicit string encodings. 
pass diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py index 69f43d99139..e8f35e12ac6 100644 --- a/xarray/tests/test_cupy.py +++ b/xarray/tests/test_cupy.py @@ -39,18 +39,18 @@ def toy_weather_data(): return ds -def test_cupy_import(): +def test_cupy_import() -> None: """Check the import worked.""" assert cp -def test_check_data_stays_on_gpu(toy_weather_data): +def test_check_data_stays_on_gpu(toy_weather_data) -> None: """Perform some operations and check the data stays on the GPU.""" freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") assert isinstance(freeze.data, cp.ndarray) -def test_where(): +def test_where() -> None: from xarray.core.duck_array_ops import where data = cp.zeros(10) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index ab0d1d9f22c..ef1ce50d6ea 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -132,13 +132,13 @@ def test_dask_distributed_read_netcdf_integration_test( @requires_zarr @pytest.mark.parametrize("consolidated", [True, False]) @pytest.mark.parametrize("compute", [True, False]) -def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): +def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> None: if consolidated: pytest.importorskip("zarr", minversion="2.2.1.dev2") write_kwargs = {"consolidated": True} read_kwargs = {"backend_kwargs": {"consolidated": True}} else: - write_kwargs = read_kwargs = {} + write_kwargs = read_kwargs = {} # type: ignore chunks = {"dim1": 4, "dim2": 3, "dim3": 5} with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): @@ -160,7 +160,7 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): @requires_rasterio -def test_dask_distributed_rasterio_integration_test(loop): +def test_dask_distributed_rasterio_integration_test(loop) -> None: with create_tmp_geotiff() as (tmp_file, expected): with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): @@ -172,7 +172,7 @@ def test_dask_distributed_rasterio_integration_test(loop): @requires_cfgrib @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") -def test_dask_distributed_cfgrib_integration_test(loop): +def test_dask_distributed_cfgrib_integration_test(loop) -> None: with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): with open_example_dataset( @@ -185,7 +185,7 @@ def test_dask_distributed_cfgrib_integration_test(loop): @gen_cluster(client=True) -async def test_async(c, s, a, b): +async def test_async(c, s, a, b) -> None: x = create_test_data() assert not dask.is_dask_collection(x) y = x.chunk({"dim2": 4}) + 10 @@ -212,12 +212,12 @@ async def test_async(c, s, a, b): assert s.tasks -def test_hdf5_lock(): +def test_hdf5_lock() -> None: assert isinstance(HDF5_LOCK, dask.utils.SerializableLock) @gen_cluster(client=True) -async def test_serializable_locks(c, s, a, b): +async def test_serializable_locks(c, s, a, b) -> None: def f(x, lock=None): with lock: return x + 1 diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 2d9fa11dda3..5ca3a35e238 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -17,7 +17,7 @@ def __init__(self, xarray_obj): class TestAccessor: - def test_register(self): + def test_register(self) -> None: @xr.register_dataset_accessor("demo") @xr.register_dataarray_accessor("demo") class DemoAccessor: @@ -41,12 +41,13 @@ def foo(self): # check descriptor assert ds.demo.__doc__ == 
"Demo accessor." - assert xr.Dataset.demo.__doc__ == "Demo accessor." - assert isinstance(ds.demo, DemoAccessor) - assert xr.Dataset.demo is DemoAccessor + # TODO: typing doesn't seem to work with accessors + assert xr.Dataset.demo.__doc__ == "Demo accessor." # type: ignore + assert isinstance(ds.demo, DemoAccessor) # type: ignore + assert xr.Dataset.demo is DemoAccessor # type: ignore # ensure we can remove it - del xr.Dataset.demo + del xr.Dataset.demo # type: ignore assert not hasattr(xr.Dataset, "demo") with pytest.warns(Warning, match="overriding a preexisting attribute"): @@ -58,7 +59,7 @@ class Foo: # it didn't get registered again assert not hasattr(xr.Dataset, "demo") - def test_pickle_dataset(self): + def test_pickle_dataset(self) -> None: ds = xr.Dataset() ds_restored = pickle.loads(pickle.dumps(ds)) assert_identical(ds, ds_restored) @@ -70,13 +71,13 @@ def test_pickle_dataset(self): assert_identical(ds, ds_restored) assert ds_restored.example_accessor.value == "foo" - def test_pickle_dataarray(self): + def test_pickle_dataarray(self) -> None: array = xr.Dataset() assert array.example_accessor is array.example_accessor array_restored = pickle.loads(pickle.dumps(array)) assert_identical(array, array_restored) - def test_broken_accessor(self): + def test_broken_accessor(self) -> None: # regression test for GH933 @xr.register_dataset_accessor("stupid_accessor") diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index d5e7d2ee232..9e53cac3aa6 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pytest +from numpy.core import defchararray import xarray as xr from xarray.core import formatting @@ -12,7 +13,7 @@ class TestFormatting: - def test_get_indexer_at_least_n_items(self): + def test_get_indexer_at_least_n_items(self) -> None: cases = [ ((20,), (slice(10),), (slice(-10, None),)), ((3, 20), (0, slice(10)), (-1, slice(-10, None))), @@ -43,7 +44,7 @@ def test_get_indexer_at_least_n_items(self): actual = formatting._get_indexer_at_least_n_items(shape, 10, from_end=True) assert end_expected == actual - def test_first_n_items(self): + def test_first_n_items(self) -> None: array = np.arange(100).reshape(10, 5, 2) for n in [3, 10, 13, 100, 200]: actual = formatting.first_n_items(array, n) @@ -53,7 +54,7 @@ def test_first_n_items(self): with pytest.raises(ValueError, match=r"at least one item"): formatting.first_n_items(array, 0) - def test_last_n_items(self): + def test_last_n_items(self) -> None: array = np.arange(100).reshape(10, 5, 2) for n in [3, 10, 13, 100, 200]: actual = formatting.last_n_items(array, n) @@ -63,7 +64,7 @@ def test_last_n_items(self): with pytest.raises(ValueError, match=r"at least one item"): formatting.first_n_items(array, 0) - def test_last_item(self): + def test_last_item(self) -> None: array = np.arange(100) reshape = ((10, 10), (1, 100), (2, 2, 5, 5)) @@ -73,7 +74,7 @@ def test_last_item(self): result = formatting.last_item(array.reshape(r)) assert result == expected - def test_format_item(self): + def test_format_item(self) -> None: cases = [ (pd.Timestamp("2000-01-01T12"), "2000-01-01T12:00:00"), (pd.Timestamp("2000-01-01"), "2000-01-01"), @@ -94,7 +95,7 @@ def test_format_item(self): actual = formatting.format_item(item) assert expected == actual - def test_format_items(self): + def test_format_items(self) -> None: cases = [ (np.arange(4) * np.timedelta64(1, "D"), "0 days 1 days 2 days 3 days"), ( @@ -116,7 +117,7 @@ def 
test_format_items(self): actual = " ".join(formatting.format_items(item)) assert expected == actual - def test_format_array_flat(self): + def test_format_array_flat(self) -> None: actual = formatting.format_array_flat(np.arange(100), 2) expected = "..." assert expected == actual @@ -180,14 +181,14 @@ def test_format_array_flat(self): expected = "'hello world hello..." assert expected == actual - def test_pretty_print(self): + def test_pretty_print(self) -> None: assert formatting.pretty_print("abcdefghij", 8) == "abcde..." assert formatting.pretty_print("ß", 1) == "ß" - def test_maybe_truncate(self): + def test_maybe_truncate(self) -> None: assert formatting.maybe_truncate("ß", 10) == "ß" - def test_format_timestamp_out_of_bounds(self): + def test_format_timestamp_out_of_bounds(self) -> None: from datetime import datetime date = datetime(1300, 12, 1) @@ -200,7 +201,7 @@ def test_format_timestamp_out_of_bounds(self): result = formatting.format_timestamp(date) assert result == expected - def test_attribute_repr(self): + def test_attribute_repr(self) -> None: short = formatting.summarize_attr("key", "Short string") long = formatting.summarize_attr("key", 100 * "Very long string ") newlines = formatting.summarize_attr("key", "\n\n\n") @@ -211,7 +212,7 @@ def test_attribute_repr(self): assert "\n" not in newlines assert "\t" not in tabs - def test_diff_array_repr(self): + def test_diff_array_repr(self) -> None: da_a = xr.DataArray( np.array([[1, 2, 3], [4, 5, 6]], dtype="int64"), dims=("x", "y"), @@ -291,7 +292,7 @@ def test_diff_array_repr(self): assert actual == expected.replace(", dtype=int64", "") @pytest.mark.filterwarnings("error") - def test_diff_attrs_repr_with_array(self): + def test_diff_attrs_repr_with_array(self) -> None: attrs_a = {"attr": np.array([0, 1])} attrs_b = {"attr": 1} @@ -305,7 +306,7 @@ def test_diff_attrs_repr_with_array(self): actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals") assert expected == actual - attrs_b = {"attr": np.array([-3, 5])} + attrs_c = {"attr": np.array([-3, 5])} expected = dedent( """\ Differing attributes: @@ -313,11 +314,11 @@ def test_diff_attrs_repr_with_array(self): R attr: [-3 5] """ ).strip() - actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals") + actual = formatting.diff_attrs_repr(attrs_a, attrs_c, "equals") assert expected == actual # should not raise a warning - attrs_b = {"attr": np.array([0, 1, 2])} + attrs_c = {"attr": np.array([0, 1, 2])} expected = dedent( """\ Differing attributes: @@ -325,10 +326,10 @@ def test_diff_attrs_repr_with_array(self): R attr: [0 1 2] """ ).strip() - actual = formatting.diff_attrs_repr(attrs_a, attrs_b, "equals") + actual = formatting.diff_attrs_repr(attrs_a, attrs_c, "equals") assert expected == actual - def test_diff_dataset_repr(self): + def test_diff_dataset_repr(self) -> None: ds_a = xr.Dataset( data_vars={ "var1": (("x", "y"), np.array([[1, 2, 3], [4, 5, 6]], dtype="int64")), @@ -380,7 +381,7 @@ def test_diff_dataset_repr(self): actual = formatting.diff_dataset_repr(ds_a, ds_b, "identical") assert actual == expected - def test_array_repr(self): + def test_array_repr(self) -> None: ds = xr.Dataset(coords={"foo": [1, 2, 3], "bar": [1, 2, 3]}) ds[(1, 2)] = xr.DataArray([0], dims="test") actual = formatting.array_repr(ds[(1, 2)]) @@ -404,7 +405,7 @@ def test_array_repr(self): assert actual == expected - def test_array_repr_variable(self): + def test_array_repr_variable(self) -> None: var = xr.Variable("x", [0, 1]) formatting.array_repr(var) @@ -413,7 +414,7 @@ def 
test_array_repr_variable(self): formatting.array_repr(var) -def test_inline_variable_array_repr_custom_repr(): +def test_inline_variable_array_repr_custom_repr() -> None: class CustomArray: def __init__(self, value, attr): self.value = value @@ -450,7 +451,7 @@ def ndim(self): assert actual == value._repr_inline_(max_width) -def test_set_numpy_options(): +def test_set_numpy_options() -> None: original_options = np.get_printoptions() with formatting.set_numpy_options(threshold=10): assert len(repr(np.arange(500))) < 200 @@ -458,7 +459,7 @@ def test_set_numpy_options(): assert np.get_printoptions() == original_options -def test_short_numpy_repr(): +def test_short_numpy_repr() -> None: cases = [ np.random.randn(500), np.random.randn(20, 20), @@ -474,7 +475,7 @@ def test_short_numpy_repr(): assert num_lines < 30 -def test_large_array_repr_length(): +def test_large_array_repr_length() -> None: da = xr.DataArray(np.random.randn(100, 5, 1)) @@ -483,7 +484,7 @@ def test_large_array_repr_length(): @requires_netCDF4 -def test_repr_file_collapsed(tmp_path): +def test_repr_file_collapsed(tmp_path) -> None: arr = xr.DataArray(np.arange(300), dims="test") arr.to_netcdf(tmp_path / "test.nc", engine="netcdf4") @@ -505,11 +506,11 @@ def test_repr_file_collapsed(tmp_path): "display_max_rows, n_vars, n_attr", [(50, 40, 30), (35, 40, 30), (11, 40, 30), (1, 40, 30)], ) -def test__mapping_repr(display_max_rows, n_vars, n_attr): +def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: long_name = "long_name" - a = np.core.defchararray.add(long_name, np.arange(0, n_vars).astype(str)) - b = np.core.defchararray.add("attr_", np.arange(0, n_attr).astype(str)) - c = np.core.defchararray.add("coord", np.arange(0, n_vars).astype(str)) + a = defchararray.add(long_name, np.arange(0, n_vars).astype(str)) + b = defchararray.add("attr_", np.arange(0, n_attr).astype(str)) + c = defchararray.add("coord", np.arange(0, n_vars).astype(str)) attrs = {k: 2 for k in b} coords = {_c: np.array([0, 1]) for _c in c} data_vars = dict() diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 09c6fa0cf3c..4ee80f65027 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -1,3 +1,5 @@ +from typing import Dict, List + import numpy as np import pandas as pd import pytest @@ -44,47 +46,49 @@ def dataset(): ) -def test_short_data_repr_html(dataarray): +def test_short_data_repr_html(dataarray) -> None: data_repr = fh.short_data_repr_html(dataarray) assert data_repr.startswith("
array")
 
 
-def test_short_data_repr_html_non_str_keys(dataset):
+def test_short_data_repr_html_non_str_keys(dataset) -> None:
     ds = dataset.assign({2: lambda x: x["tmin"]})
     fh.dataset_repr(ds)
 
 
-def test_short_data_repr_html_dask(dask_dataarray):
+def test_short_data_repr_html_dask(dask_dataarray) -> None:
     assert hasattr(dask_dataarray.data, "_repr_html_")
     data_repr = fh.short_data_repr_html(dask_dataarray)
     assert data_repr == dask_dataarray.data._repr_html_()
 
 
-def test_format_dims_no_dims():
-    dims, coord_names = {}, []
+def test_format_dims_no_dims() -> None:
+    dims: Dict = {}
+    coord_names: List = []
     formatted = fh.format_dims(dims, coord_names)
     assert formatted == ""
 
 
-def test_format_dims_unsafe_dim_name():
-    dims, coord_names = {"<x>": 3, "y": 2}, []
+def test_format_dims_unsafe_dim_name() -> None:
+    dims = {"": 3, "y": 2}
+    coord_names: List = []
     formatted = fh.format_dims(dims, coord_names)
     assert "<x>" in formatted
 
 
-def test_format_dims_non_index():
+def test_format_dims_non_index() -> None:
     dims, coord_names = {"x": 3, "y": 2}, ["time"]
     formatted = fh.format_dims(dims, coord_names)
     assert "class='xr-has-index'" not in formatted
 
 
-def test_format_dims_index():
+def test_format_dims_index() -> None:
     dims, coord_names = {"x": 3, "y": 2}, ["x"]
     formatted = fh.format_dims(dims, coord_names)
     assert "class='xr-has-index'" in formatted
 
 
-def test_summarize_attrs_with_unsafe_attr_name_and_value():
+def test_summarize_attrs_with_unsafe_attr_name_and_value() -> None:
     attrs = {"": 3, "y": ""}
     formatted = fh.summarize_attrs(attrs)
     assert "
<x> :
" in formatted @@ -93,7 +97,7 @@ def test_summarize_attrs_with_unsafe_attr_name_and_value(): assert "
<pd.DataFrame>
" in formatted -def test_repr_of_dataarray(dataarray): +def test_repr_of_dataarray(dataarray) -> None: formatted = fh.array_repr(dataarray) assert "dim_0" in formatted # has an expanded data section @@ -115,7 +119,7 @@ def test_repr_of_dataarray(dataarray): ) -def test_summary_of_multiindex_coord(multiindex): +def test_summary_of_multiindex_coord(multiindex) -> None: idx = multiindex.x.variable.to_index_variable() formatted = fh._summarize_coord_multiindex("foo", idx) assert "(level_1, level_2)" in formatted @@ -123,12 +127,12 @@ def test_summary_of_multiindex_coord(multiindex): assert "foo" in formatted -def test_repr_of_multiindex(multiindex): +def test_repr_of_multiindex(multiindex) -> None: formatted = fh.dataset_repr(multiindex) assert "(x)" in formatted -def test_repr_of_dataset(dataset): +def test_repr_of_dataset(dataset) -> None: formatted = fh.dataset_repr(dataset) # coords, attrs, and data_vars are expanded assert ( @@ -152,14 +156,14 @@ def test_repr_of_dataset(dataset): assert "<IA>" in formatted -def test_repr_text_fallback(dataset): +def test_repr_text_fallback(dataset) -> None: formatted = fh.dataset_repr(dataset) # Just test that the "pre" block used for fallback to plain text is present. assert "
" in formatted
 
 
-def test_variable_repr_html():
+def test_variable_repr_html() -> None:
     v = xr.Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"})
     assert hasattr(v, "_repr_html_")
     with xr.set_options(display_style="html"):
@@ -171,7 +175,7 @@ def test_variable_repr_html():
     assert "xarray.Variable" in html
 
 
-def test_repr_of_nonstr_dataset(dataset):
+def test_repr_of_nonstr_dataset(dataset) -> None:
     ds = dataset.copy()
     ds.attrs[1] = "Test value"
     ds[2] = ds["tmin"]
@@ -180,7 +184,7 @@ def test_repr_of_nonstr_dataset(dataset):
     assert "
2" in formatted -def test_repr_of_nonstr_dataarray(dataarray): +def test_repr_of_nonstr_dataarray(dataarray) -> None: da = dataarray.rename(dim_0=15) da.attrs[1] = "value" formatted = fh.array_repr(da) @@ -188,7 +192,7 @@ def test_repr_of_nonstr_dataarray(dataarray): assert "
<li><span>15</span>: 4</li>
  • " in formatted -def test_nonstr_variable_repr_html(): +def test_nonstr_variable_repr_html() -> None: v = xr.Variable(["time", 10], [[1, 2, 3], [4, 5, 6]], {22: "bar"}) assert hasattr(v, "_repr_html_") with xr.set_options(display_style="html"): diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ee77865dd24..d48726e8304 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -34,7 +34,7 @@ def array(dataset): return dataset["foo"] -def test_consolidate_slices(): +def test_consolidate_slices() -> None: assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)] assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)] @@ -47,7 +47,7 @@ def test_consolidate_slices(): _consolidate_slices([slice(3), 4]) -def test_groupby_dims_property(dataset): +def test_groupby_dims_property(dataset) -> None: assert dataset.groupby("x").dims == dataset.isel(x=1).dims assert dataset.groupby("y").dims == dataset.isel(y=1).dims @@ -55,7 +55,7 @@ def test_groupby_dims_property(dataset): assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims -def test_multi_index_groupby_map(dataset): +def test_multi_index_groupby_map(dataset) -> None: # regression test for GH873 ds = dataset.isel(z=1, drop=True)[["foo"]] expected = 2 * ds @@ -68,7 +68,7 @@ def test_multi_index_groupby_map(dataset): assert_equal(expected, actual) -def test_multi_index_groupby_sum(): +def test_multi_index_groupby_sum() -> None: # regression test for GH873 ds = xr.Dataset( {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))}, @@ -79,7 +79,7 @@ def test_multi_index_groupby_sum(): assert_equal(expected, actual) -def test_groupby_da_datetime(): +def test_groupby_da_datetime() -> None: # test groupby with a DataArray of dtype datetime for GH1132 # create test data times = pd.date_range("2000-01-01", periods=4) @@ -99,7 +99,7 @@ def test_groupby_da_datetime(): assert_equal(expected, actual) -def test_groupby_duplicate_coordinate_labels(): +def test_groupby_duplicate_coordinate_labels() -> None: # fix for http://stackoverflow.com/questions/38065129 array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])]) expected = xr.DataArray([3, 3], [("x", [1, 2])]) @@ -107,7 +107,7 @@ def test_groupby_duplicate_coordinate_labels(): assert_equal(expected, actual) -def test_groupby_input_mutation(): +def test_groupby_input_mutation() -> None: # regression test for GH2153 array = xr.DataArray([1, 2, 3], [("x", [2, 2, 1])]) array_copy = array.copy() @@ -124,7 +124,7 @@ def test_groupby_input_mutation(): xr.Dataset({"foo": ("x", [1, 2, 3, 4, 5, 6])}, {"x": [1, 1, 1, 2, 2, 2]}), ], ) -def test_groupby_map_shrink_groups(obj): +def test_groupby_map_shrink_groups(obj) -> None: expected = obj.isel(x=[0, 1, 3, 4]) actual = obj.groupby("x").map(lambda f: f.isel(x=[0, 1])) assert_identical(expected, actual) @@ -137,7 +137,7 @@ def test_groupby_map_shrink_groups(obj): xr.Dataset({"foo": ("x", [1, 2, 3])}, {"x": [1, 2, 2]}), ], ) -def test_groupby_map_change_group_size(obj): +def test_groupby_map_change_group_size(obj) -> None: def func(group): if group.sizes["x"] == 1: result = group.isel(x=[0, 0]) @@ -150,7 +150,7 @@ def func(group): assert_identical(expected, actual) -def test_da_groupby_map_func_args(): +def test_da_groupby_map_func_args() -> None: def func(arg1, arg2, arg3=0): return arg1 + arg2 + arg3 @@ -160,7 +160,7 @@ def func(arg1, arg2, arg3=0): assert_identical(expected, actual) -def test_ds_groupby_map_func_args(): +def test_ds_groupby_map_func_args() -> None: def func(arg1, arg2, arg3=0): return arg1 
+ arg2 + arg3 @@ -170,7 +170,7 @@ def func(arg1, arg2, arg3=0): assert_identical(expected, actual) -def test_da_groupby_empty(): +def test_da_groupby_empty() -> None: empty_array = xr.DataArray([], dims="dim") @@ -178,7 +178,7 @@ def test_da_groupby_empty(): empty_array.groupby("dim") -def test_da_groupby_quantile(): +def test_da_groupby_quantile() -> None: array = xr.DataArray( data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" @@ -274,7 +274,7 @@ def test_da_groupby_quantile(): assert_identical(expected, actual) -def test_ds_groupby_quantile(): +def test_ds_groupby_quantile() -> None: ds = xr.Dataset( data_vars={"a": ("x", [1, 2, 3, 4, 5, 6])}, coords={"x": [1, 1, 1, 2, 2, 2]} ) @@ -368,7 +368,7 @@ def test_ds_groupby_quantile(): assert_identical(expected, actual) -def test_da_groupby_assign_coords(): +def test_da_groupby_assign_coords() -> None: actual = xr.DataArray( [[3, 4, 5], [6, 7, 8]], dims=["y", "x"], coords={"y": range(2), "x": range(3)} ) @@ -395,7 +395,7 @@ def test_da_groupby_assign_coords(): @pytest.mark.parametrize("dim", ["x", "y", "z", "month"]) @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) -def test_groupby_repr(obj, dim): +def test_groupby_repr(obj, dim) -> None: actual = repr(obj.groupby(dim)) expected = f"{obj.__class__.__name__}GroupBy" expected += ", grouped over %r" % dim @@ -412,7 +412,7 @@ def test_groupby_repr(obj, dim): @pytest.mark.parametrize("obj", [repr_da, repr_da.to_dataset(name="a")]) -def test_groupby_repr_datetime(obj): +def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"{obj.__class__.__name__}GroupBy" expected += ", grouped over 'month'" @@ -421,7 +421,7 @@ def test_groupby_repr_datetime(obj): assert actual == expected -def test_groupby_drops_nans(): +def test_groupby_drops_nans() -> None: # GH2383 # nan in 2D data variable (requires stacking) ds = xr.Dataset( @@ -463,9 +463,9 @@ def test_groupby_drops_nans(): # NaN in non-dimensional coordinate array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])]) array["x1"] = ("x", [1, 1, np.nan]) - expected = xr.DataArray(3, [("x1", [1])]) + expected_da = xr.DataArray(3, [("x1", [1])]) actual = array.groupby("x1").sum() - assert_equal(expected, actual) + assert_equal(expected_da, actual) # NaT in non-dimensional coordinate array["t"] = ( @@ -476,18 +476,18 @@ def test_groupby_drops_nans(): np.datetime64("NaT"), ], ) - expected = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])]) + expected_da = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])]) actual = array.groupby("t").sum() - assert_equal(expected, actual) + assert_equal(expected_da, actual) # test for repeated coordinate labels array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])]) - expected = xr.DataArray([3, 3], [("x", [1, 2])]) + expected_da = xr.DataArray([3, 3], [("x", [1, 2])]) actual = array.groupby("x").sum() - assert_equal(expected, actual) + assert_equal(expected_da, actual) -def test_groupby_grouping_errors(): +def test_groupby_grouping_errors() -> None: dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]}) with pytest.raises( ValueError, match=r"None of the data falls within bins with edges" @@ -512,7 +512,7 @@ def test_groupby_grouping_errors(): dataset.to_array().groupby(dataset.foo * np.nan) -def test_groupby_reduce_dimension_error(array): +def test_groupby_reduce_dimension_error(array) -> None: grouped = array.groupby("y") with pytest.raises(ValueError, match=r"cannot reduce over dimensions"): 
         grouped.mean()
@@ -530,17 +530,17 @@ def test_groupby_reduce_dimension_error(array):
     assert_allclose(array.mean(["x", "z"]), grouped.reduce(np.mean, ["x", "z"]))
 
 
-def test_groupby_multiple_string_args(array):
+def test_groupby_multiple_string_args(array) -> None:
     with pytest.raises(TypeError):
         array.groupby("x", "y")
 
 
-def test_groupby_bins_timeseries():
+def test_groupby_bins_timeseries() -> None:
     ds = xr.Dataset()
     ds["time"] = xr.DataArray(
         pd.date_range("2010-08-01", "2010-08-15", freq="15min"), dims="time"
     )
-    ds["val"] = xr.DataArray(np.ones(*ds["time"].shape), dims="time")
+    ds["val"] = xr.DataArray(np.ones(ds["time"].shape), dims="time")
     time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24H")
     actual = ds.groupby_bins("time", time_bins).sum()
     expected = xr.DataArray(
@@ -551,7 +551,7 @@ def test_groupby_bins_timeseries():
     assert_identical(actual, expected)
 
 
-def test_groupby_none_group_name():
+def test_groupby_none_group_name() -> None:
     # GH158
     # xarray should not fail if a DataArray's name attribute is None
 
@@ -563,7 +563,7 @@ def test_groupby_none_group_name():
     assert "group" in mean.dims
 
 
-def test_groupby_getitem(dataset):
+def test_groupby_getitem(dataset) -> None:
 
     assert_identical(dataset.sel(x="a"), dataset.groupby("x")["a"])
     assert_identical(dataset.sel(z=1), dataset.groupby("z")[1])
@@ -576,7 +576,7 @@ def test_groupby_getitem(dataset):
     assert_identical(expected, actual)
 
 
-def test_groupby_dataset():
+def test_groupby_dataset() -> None:
     data = Dataset(
         {"z": (["x", "y"], np.random.randn(3, 5))},
         {"x": ("x", list("abc")), "c": ("x", [0, 1, 0]), "y": range(5)},
@@ -602,7 +602,7 @@ def identity(x):
     assert_equal(data, actual)
 
 
-def test_groupby_dataset_returns_new_type():
+def test_groupby_dataset_returns_new_type() -> None:
     data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))})
 
     actual = data.groupby("x").map(lambda ds: ds["z"])
@@ -610,11 +610,11 @@ def test_groupby_dataset_returns_new_type():
     assert_identical(expected, actual)
 
     actual = data["z"].groupby("x").map(lambda x: x.to_dataset())
-    expected = data
-    assert_identical(expected, actual)
+    expected_ds = data
+    assert_identical(expected_ds, actual)
 
 
-def test_groupby_dataset_iter():
+def test_groupby_dataset_iter() -> None:
     data = create_test_data()
     for n, (t, sub) in enumerate(list(data.groupby("dim1"))[:3]):
         assert data["dim1"][n] == t
@@ -623,7 +623,7 @@ def test_groupby_dataset_iter():
         assert_equal(data["var3"][:, n], sub["var3"])
 
 
-def test_groupby_dataset_errors():
+def test_groupby_dataset_errors() -> None:
     data = create_test_data()
     with pytest.raises(TypeError, match=r"`group` must be"):
         data.groupby(np.arange(10))
@@ -633,7 +633,7 @@ def test_groupby_dataset_errors():
         data.groupby(data.coords["dim1"].to_index())
 
 
-def test_groupby_dataset_reduce():
+def test_groupby_dataset_reduce() -> None:
     data = Dataset(
         {
             "xy": (["x", "y"], np.random.randn(3, 4)),
@@ -663,7 +663,7 @@ def test_groupby_dataset_reduce():
     assert_allclose(expected, actual)
 
 
-def test_groupby_dataset_math():
+def test_groupby_dataset_math() -> None:
     def reorder_dims(x):
         return x.transpose("dim1", "dim2", "dim3", "time")
 
@@ -719,7 +719,7 @@ def reorder_dims(x):
         ds + ds.groupby("time.month")
 
 
-def test_groupby_dataset_math_virtual():
+def test_groupby_dataset_math_virtual() -> None:
     ds = Dataset({"x": ("t", [1, 2, 3])}, {"t": pd.date_range("20100101", periods=3)})
     grouped = ds.groupby("t.day")
     actual = grouped - grouped.mean(...)
@@ -727,7 +727,7 @@ def test_groupby_dataset_math_virtual():
     assert_identical(actual, expected)
 
 
-def test_groupby_dataset_nan():
+def test_groupby_dataset_nan() -> None:
     # nan should be excluded from groupby
     ds = Dataset({"foo": ("x", [1, 2, 3, 4])}, {"bar": ("x", [1, 1, 2, np.nan])})
    actual = ds.groupby("bar").mean(...)
@@ -735,7 +735,7 @@
     assert_identical(actual, expected)
 
 
-def test_groupby_dataset_order():
+def test_groupby_dataset_order() -> None:
     # groupby should preserve variables order
     ds = Dataset()
     for vn in ["a", "b", "c"]:
diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py
index c8ba72a253f..18f76df765d 100644
--- a/xarray/tests/test_indexes.py
+++ b/xarray/tests/test_indexes.py
@@ -7,7 +7,7 @@
 from xarray.core.variable import IndexVariable
 
 
-def test_asarray_tuplesafe():
+def test_asarray_tuplesafe() -> None:
     res = _asarray_tuplesafe(("a", 1))
     assert isinstance(res, np.ndarray)
     assert res.ndim == 0
@@ -20,14 +20,14 @@
 
 
 class TestPandasIndex:
-    def test_constructor(self):
+    def test_constructor(self) -> None:
         pd_idx = pd.Index([1, 2, 3])
         index = PandasIndex(pd_idx, "x")
 
         assert index.index is pd_idx
         assert index.dim == "x"
 
-    def test_from_variables(self):
+    def test_from_variables(self) -> None:
         var = xr.Variable(
             "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32}
         )
@@ -46,7 +46,7 @@ def test_from_variables(self):
         ):
             PandasIndex.from_variables({"foo": var2})
 
-    def test_from_pandas_index(self):
+    def test_from_pandas_index(self) -> None:
         pd_idx = pd.Index([1, 2, 3], name="foo")
 
         index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x")
@@ -68,7 +68,7 @@ def to_pandas_index(self):
         index = PandasIndex(pd_idx, "x")
         assert index.to_pandas_index() is pd_idx
 
-    def test_query(self):
+    def test_query(self) -> None:
         # TODO: add tests that aren't just for edge cases
         index = PandasIndex(pd.Index([1, 2, 3]), "x")
         with pytest.raises(KeyError, match=r"not all values found"):
@@ -78,7 +78,7 @@
         with pytest.raises(ValueError, match=r"does not have a MultiIndex"):
             index.query({"x": {"one": 0}})
 
-    def test_query_datetime(self):
+    def test_query_datetime(self) -> None:
         index = PandasIndex(
             pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x"
         )
@@ -89,7 +89,7 @@
         actual = index.query({"x": index.to_pandas_index().to_numpy()[1]})
         assert actual == expected
 
-    def test_query_unsorted_datetime_index_raises(self):
+    def test_query_unsorted_datetime_index_raises(self) -> None:
         index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x")
         with pytest.raises(KeyError):
             # pandas will try to convert this into an array indexer. We should
@@ -97,26 +97,26 @@
             # slice is always a view.
             index.query({"x": slice("2001", "2002")})
 
-    def test_equals(self):
+    def test_equals(self) -> None:
         index1 = PandasIndex([1, 2, 3], "x")
         index2 = PandasIndex([1, 2, 3], "x")
         assert index1.equals(index2) is True
 
-    def test_union(self):
+    def test_union(self) -> None:
         index1 = PandasIndex([1, 2, 3], "x")
         index2 = PandasIndex([4, 5, 6], "y")
         actual = index1.union(index2)
         assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6]))
         assert actual.dim == "x"
 
-    def test_intersection(self):
+    def test_intersection(self) -> None:
         index1 = PandasIndex([1, 2, 3], "x")
         index2 = PandasIndex([2, 3, 4], "y")
         actual = index1.intersection(index2)
         assert actual.index.equals(pd.Index([2, 3]))
         assert actual.dim == "x"
 
-    def test_copy(self):
+    def test_copy(self) -> None:
         expected = PandasIndex([1, 2, 3], "x")
         actual = expected.copy()
 
@@ -124,7 +124,7 @@
         assert actual.index is not expected.index
         assert actual.dim == expected.dim
 
-    def test_getitem(self):
+    def test_getitem(self) -> None:
         pd_idx = pd.Index([1, 2, 3])
         expected = PandasIndex(pd_idx, "x")
         actual = expected[1:]
@@ -134,7 +134,7 @@
 
 
 class TestPandasMultiIndex:
-    def test_from_variables(self):
+    def test_from_variables(self) -> None:
         v_level1 = xr.Variable(
             "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32}
         )
@@ -165,7 +165,7 @@
         with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"):
             PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3})
 
-    def test_from_pandas_index(self):
+    def test_from_pandas_index(self) -> None:
         pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar"))
 
         index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x")
@@ -177,7 +177,7 @@
         xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3]))
         xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", [4, 5, 6]))
 
-    def test_query(self):
+    def test_query(self) -> None:
         index = PandasMultiIndex(
             pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x"
         )
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 6e4fd320029..533f4a0cd62 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -18,7 +18,7 @@ def set_to_zero(self, x, i):
         x[i] = 0
         return x
 
-    def test_expanded_indexer(self):
+    def test_expanded_indexer(self) -> None:
         x = np.random.randn(10, 11, 12, 13, 14)
         y = np.arange(5)
         arr = ReturnItem()
@@ -40,7 +40,7 @@
         with pytest.raises(IndexError, match=r"too many indices"):
             indexing.expanded_indexer(arr[1, 2, 3], 2)
 
-    def test_stacked_multiindex_min_max(self):
+    def test_stacked_multiindex_min_max(self) -> None:
         data = np.random.randn(3, 23, 4)
         da = DataArray(
             data,
@@ -55,7 +55,7 @@
         assert_array_equal(da2.loc["a", s.max()], data[2, 22, 0])
         assert_array_equal(da2.loc["b", s.min()], data[0, 0, 1])
 
-    def test_group_indexers_by_index(self):
+    def test_group_indexers_by_index(self) -> None:
         mindex = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
         data = DataArray(
             np.zeros((4, 2, 2)), coords={"x": mindex, "y": [1, 2]}, dims=("x", "y", "z")
@@ -79,8 +79,8 @@
         with pytest.raises(ValueError, match=r"cannot supply.*"):
             indexing.group_indexers_by_index(data, {"z": 1}, method="nearest")
 
-    def test_remap_label_indexers(self):
-        def test_indexer(data, x, expected_pos, expected_idx=None):
+    def test_remap_label_indexers(self) -> None:
+        def test_indexer(data, x, expected_pos, expected_idx=None) -> None:
             pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x})
             idx, _ = new_idx_vars.get("x", (None, None))
             if idx is not None:
@@ -139,7 +139,7 @@ def test_indexer(data, x, expected_pos, expected_idx=None):
             pd.MultiIndex.from_product([[1, 2], [-1, -2]]),
         )
 
-    def test_read_only_view(self):
+    def test_read_only_view(self) -> None:
 
         arr = DataArray(
             np.random.rand(3, 3),
@@ -153,7 +153,7 @@
 
 
 class TestLazyArray:
-    def test_slice_slice(self):
+    def test_slice_slice(self) -> None:
         arr = ReturnItem()
         for size in [100, 99]:
             # We test even/odd size cases
@@ -183,7 +183,7 @@
             actual = x[new_slice]
             assert_array_equal(expected, actual)
 
-    def test_lazily_indexed_array(self):
+    def test_lazily_indexed_array(self) -> None:
         original = np.random.rand(10, 20, 30)
         x = indexing.NumpyIndexingAdapter(original)
         v = Variable(["i", "j", "k"], original)
@@ -228,17 +228,17 @@
             ([0, 3, 5], arr[:2]),
         ]
         for i, j in indexers:
-            expected = v[i][j]
+            expected_b = v[i][j]
             actual = v_lazy[i][j]
-            assert expected.shape == actual.shape
-            assert_array_equal(expected, actual)
+            assert expected_b.shape == actual.shape
+            assert_array_equal(expected_b, actual)
 
             # test transpose
             if actual.ndim > 1:
                 order = np.random.choice(actual.ndim, actual.ndim)
                 order = np.array(actual.dims)
                 transposed = actual.transpose(*order)
-                assert_array_equal(expected.transpose(*order), transposed)
+                assert_array_equal(expected_b.transpose(*order), transposed)
             assert isinstance(
                 actual._data,
                 (
@@ -250,7 +250,7 @@
         assert isinstance(actual._data, indexing.LazilyIndexedArray)
         assert isinstance(actual._data.array, indexing.NumpyIndexingAdapter)
 
-    def test_vectorized_lazily_indexed_array(self):
+    def test_vectorized_lazily_indexed_array(self) -> None:
         original = np.random.rand(10, 20, 30)
         x = indexing.NumpyIndexingAdapter(original)
         v_eager = Variable(["i", "j", "k"], x)
@@ -300,14 +300,14 @@ def check_indexing(v_eager, v_lazy, indexers):
 
 
 class TestCopyOnWriteArray:
-    def test_setitem(self):
+    def test_setitem(self) -> None:
         original = np.arange(10)
         wrapped = indexing.CopyOnWriteArray(original)
         wrapped[B[:]] = 0
         assert_array_equal(original, np.arange(10))
         assert_array_equal(wrapped, np.zeros(10))
 
-    def test_sub_array(self):
+    def test_sub_array(self) -> None:
         original = np.arange(10)
         wrapped = indexing.CopyOnWriteArray(original)
         child = wrapped[B[:5]]
@@ -317,20 +317,20 @@
         assert_array_equal(wrapped, np.arange(10))
         assert_array_equal(child, np.zeros(5))
 
-    def test_index_scalar(self):
+    def test_index_scalar(self) -> None:
         # regression test for GH1374
         x = indexing.CopyOnWriteArray(np.array(["foo", "bar"]))
         assert np.array(x[B[0]][B[()]]) == "foo"
 
 
 class TestMemoryCachedArray:
-    def test_wrapper(self):
+    def test_wrapper(self) -> None:
         original = indexing.LazilyIndexedArray(np.arange(10))
         wrapped = indexing.MemoryCachedArray(original)
         assert_array_equal(wrapped, np.arange(10))
         assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter)
 
-    def test_sub_array(self):
+    def test_sub_array(self) -> None:
         original = indexing.LazilyIndexedArray(np.arange(10))
         wrapped = indexing.MemoryCachedArray(original)
         child = wrapped[B[:5]]
@@ -339,19 +339,19 @@
         assert isinstance(child.array, indexing.NumpyIndexingAdapter)
         assert isinstance(wrapped.array, indexing.LazilyIndexedArray)
 
-    def test_setitem(self):
+    def test_setitem(self) -> None:
         original = np.arange(10)
         wrapped = indexing.MemoryCachedArray(original)
         wrapped[B[:]] = 0
         assert_array_equal(original, np.zeros(10))
 
-    def test_index_scalar(self):
+    def test_index_scalar(self) -> None:
         # regression test for GH1374
         x = indexing.MemoryCachedArray(np.array(["foo", "bar"]))
         assert np.array(x[B[0]][B[()]]) == "foo"
 
 
-def test_base_explicit_indexer():
+def test_base_explicit_indexer() -> None:
     with pytest.raises(TypeError):
         indexing.ExplicitIndexer(())
 
@@ -367,7 +367,7 @@ class Subclass(indexing.ExplicitIndexer):
     "indexer_cls",
     [indexing.BasicIndexer, indexing.OuterIndexer, indexing.VectorizedIndexer],
 )
-def test_invalid_for_all(indexer_cls):
+def test_invalid_for_all(indexer_cls) -> None:
     with pytest.raises(TypeError):
         indexer_cls(None)
     with pytest.raises(TypeError):
@@ -409,7 +409,7 @@ def check_array2d(indexer_cls):
     np.testing.assert_array_equal(value, array)
 
 
-def test_basic_indexer():
+def test_basic_indexer() -> None:
     check_integer(indexing.BasicIndexer)
     check_slice(indexing.BasicIndexer)
     with pytest.raises(TypeError):
@@ -418,7 +418,7 @@ def test_basic_indexer():
         check_array2d(indexing.BasicIndexer)
 
 
-def test_outer_indexer():
+def test_outer_indexer() -> None:
     check_integer(indexing.OuterIndexer)
     check_slice(indexing.OuterIndexer)
     check_array1d(indexing.OuterIndexer)
@@ -426,7 +426,7 @@ def test_outer_indexer():
     check_array2d(indexing.OuterIndexer)
 
 
-def test_vectorized_indexer():
+def test_vectorized_indexer() -> None:
     with pytest.raises(TypeError):
         check_integer(indexing.VectorizedIndexer)
     check_slice(indexing.VectorizedIndexer)
@@ -450,7 +450,7 @@ def setup(self):
             slice(None),
         ]
 
-    def test_arrayize_vectorized_indexer(self):
+    def test_arrayize_vectorized_indexer(self) -> None:
         for i, j, k in itertools.product(self.indexers, repeat=3):
             vindex = indexing.VectorizedIndexer((i, j, k))
             vindex_array = indexing._arrayize_vectorized_indexer(
@@ -530,7 +530,7 @@ def get_indexers(shape, mode):
 @pytest.mark.parametrize(
     "sl", [slice(1, -1, 1), slice(None, -1, 2), slice(-1, 1, -1), slice(-1, 1, -2)]
 )
-def test_decompose_slice(size, sl):
+def test_decompose_slice(size, sl) -> None:
     x = np.arange(size)
     slice1, slice2 = indexing._decompose_slice(sl, size)
     expected = x[sl]
@@ -562,7 +562,7 @@
         indexing.IndexingSupport.VECTORIZED,
     ],
 )
-def test_decompose_indexers(shape, indexer_mode, indexing_support):
+def test_decompose_indexers(shape, indexer_mode, indexing_support) -> None:
     data = np.random.randn(*shape)
     indexer = get_indexers(shape, indexer_mode)
 
@@ -580,7 +580,7 @@ def test_decompose_indexers(shape, indexer_mode, indexing_support):
     np.testing.assert_array_equal(expected, array)
 
 
-def test_implicit_indexing_adapter():
+def test_implicit_indexing_adapter() -> None:
     array = np.arange(10, dtype=np.int64)
     implicit = indexing.ImplicitToExplicitIndexingAdapter(
         indexing.NumpyIndexingAdapter(array), indexing.BasicIndexer
@@ -589,7 +589,7 @@ def test_implicit_indexing_adapter():
     np.testing.assert_array_equal(array, implicit[:])
 
 
-def test_implicit_indexing_adapter_copy_on_write():
+def test_implicit_indexing_adapter_copy_on_write() -> None:
     array = np.arange(10, dtype=np.int64)
     implicit = indexing.ImplicitToExplicitIndexingAdapter(
         indexing.CopyOnWriteArray(array)
@@ -597,7 +597,7 @@
     assert isinstance(implicit[:], indexing.ImplicitToExplicitIndexingAdapter)
 
 
-def test_outer_indexer_consistency_with_broadcast_indexes_vectorized():
+def test_outer_indexer_consistency_with_broadcast_indexes_vectorized() -> None:
     def nonzero(x):
         if isinstance(x, np.ndarray) and x.dtype.kind == "b":
             x = x.nonzero()[0]
@@ -635,7 +635,7 @@ def nonzero(x):
     np.testing.assert_array_equal(actual_data, expected_data)
 
 
-def test_create_mask_outer_indexer():
+def test_create_mask_outer_indexer() -> None:
     indexer = indexing.OuterIndexer((np.array([0, -1, 2]),))
     expected = np.array([False, True, False])
     actual = indexing.create_mask(indexer, (5,))
@@ -647,7 +647,7 @@ def test_create_mask_outer_indexer():
     np.testing.assert_array_equal(expected, actual)
 
 
-def test_create_mask_vectorized_indexer():
+def test_create_mask_vectorized_indexer() -> None:
     indexer = indexing.VectorizedIndexer((np.array([0, -1, 2]), np.array([0, 1, -1])))
     expected = np.array([False, True, True])
     actual = indexing.create_mask(indexer, (5,))
@@ -661,7 +661,7 @@
     np.testing.assert_array_equal(expected, actual)
 
 
-def test_create_mask_basic_indexer():
+def test_create_mask_basic_indexer() -> None:
     indexer = indexing.BasicIndexer((-1,))
     actual = indexing.create_mask(indexer, (3,))
     np.testing.assert_array_equal(True, actual)
@@ -671,7 +671,7 @@ def test_create_mask_basic_indexer():
     np.testing.assert_array_equal(False, actual)
 
 
-def test_create_mask_dask():
+def test_create_mask_dask() -> None:
     da = pytest.importorskip("dask.array")
 
     indexer = indexing.OuterIndexer((1, slice(2), np.array([0, -1, 2])))
@@ -682,21 +682,21 @@ def test_create_mask_dask():
     assert actual.chunks == ((1, 1), (2, 1))
     np.testing.assert_array_equal(expected, actual)
 
-    indexer = indexing.VectorizedIndexer(
+    indexer_vec = indexing.VectorizedIndexer(
         (np.array([0, -1, 2]), slice(None), np.array([0, 1, -1]))
     )
     expected = np.array([[False, True, True]] * 2).T
     actual = indexing.create_mask(
-        indexer, (5, 2), da.empty((3, 2), chunks=((3,), (2,)))
+        indexer_vec, (5, 2), da.empty((3, 2), chunks=((3,), (2,)))
     )
     assert isinstance(actual, da.Array)
     np.testing.assert_array_equal(expected, actual)
 
     with pytest.raises(ValueError):
-        indexing.create_mask(indexer, (5, 2), da.empty((5,), chunks=(1,)))
+        indexing.create_mask(indexer_vec, (5, 2), da.empty((5,), chunks=(1,)))
 
 
-def test_create_mask_error():
+def test_create_mask_error() -> None:
     with pytest.raises(TypeError, match=r"unexpected key type"):
         indexing.create_mask((1, 2), (3, 4))
 
@@ -713,12 +713,12 @@
         (np.array([0, -1, -1, -1, 1]), np.array([0, 0, 0, 0, 1])),
     ],
 )
-def test_posify_mask_subindexer(indices, expected):
+def test_posify_mask_subindexer(indices, expected) -> None:
     actual = indexing._posify_mask_subindexer(indices)
     np.testing.assert_array_equal(expected, actual)
 
 
-def test_indexing_1d_object_array():
+def test_indexing_1d_object_array() -> None:
     items = (np.arange(3), np.arange(6))
     arr = DataArray(np.array(items, dtype=object))
 
diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py
index 3c9c92ae2ba..ba8e70ea514 100644
--- a/xarray/tests/test_nputils.py
+++ b/xarray/tests/test_nputils.py
@@ -4,13 +4,13 @@
 from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous
 
 
-def test_is_contiguous():
+def test_is_contiguous() -> None:
     assert _is_contiguous([1])
     assert _is_contiguous([1, 2, 3])
     assert not _is_contiguous([1, 3])
 
 
-def test_vindex():
+def test_vindex() -> None:
     x = np.arange(3 * 4 * 5).reshape((3, 4, 5))
     vindex = NumpyVIndexAdapter(x)
 
diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py
index 19f74476ced..be71500dc0a 100644
--- a/xarray/tests/test_options.py
+++ b/xarray/tests/test_options.py
@@ -7,12 +7,12 @@
 from xarray.tests.test_dataset import create_test_data
 
 
-def test_invalid_option_raises():
+def test_invalid_option_raises() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(not_a_valid_options=True)
 
 
-def test_display_width():
+def test_display_width() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(display_width=0)
     with pytest.raises(ValueError):
@@ -21,14 +21,14 @@
         xarray.set_options(display_width=3.5)
 
 
-def test_arithmetic_join():
+def test_arithmetic_join() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(arithmetic_join="invalid")
     with xarray.set_options(arithmetic_join="exact"):
         assert OPTIONS["arithmetic_join"] == "exact"
 
 
-def test_enable_cftimeindex():
+def test_enable_cftimeindex() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(enable_cftimeindex=None)
     with pytest.warns(FutureWarning, match="no-op"):
@@ -36,7 +36,7 @@
     assert OPTIONS["enable_cftimeindex"]
 
 
-def test_file_cache_maxsize():
+def test_file_cache_maxsize() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(file_cache_maxsize=0)
     original_size = FILE_CACHE.maxsize
@@ -45,7 +45,7 @@
     assert FILE_CACHE.maxsize == original_size
 
 
-def test_keep_attrs():
+def test_keep_attrs() -> None:
     with pytest.raises(ValueError):
         xarray.set_options(keep_attrs="invalid_str")
     with xarray.set_options(keep_attrs=True):
@@ -57,7 +57,7 @@
     assert not _get_keep_attrs(default=False)
 
 
-def test_nested_options():
+def test_nested_options() -> None:
     original = OPTIONS["display_width"]
     with xarray.set_options(display_width=1):
         assert OPTIONS["display_width"] == 1
@@ -67,7 +67,7 @@
     assert OPTIONS["display_width"] == original
 
 
-def test_display_style():
+def test_display_style() -> None:
     original = "html"
     assert OPTIONS["display_style"] == original
     with pytest.raises(ValueError):
@@ -90,7 +90,7 @@ def create_test_dataarray_attrs(seed=0, var="var1"):
 
 
 class TestAttrRetention:
-    def test_dataset_attr_retention(self):
+    def test_dataset_attr_retention(self) -> None:
         # Use .mean() for all tests: a typical reduction operation
         ds = create_test_dataset_attrs()
         original_attrs = ds.attrs
@@ -110,7 +110,7 @@ def test_dataset_attr_retention(self):
         result = ds.mean()
         assert result.attrs == {}
 
-    def test_dataarray_attr_retention(self):
+    def test_dataarray_attr_retention(self) -> None:
         # Use .mean() for all tests: a typical reduction operation
         da = create_test_dataarray_attrs()
         original_attrs = da.attrs
@@ -130,7 +130,7 @@ def test_dataarray_attr_retention(self):
         result = da.mean()
         assert result.attrs == {}
 
-    def test_groupby_attr_retention(self):
+    def test_groupby_attr_retention(self) -> None:
         da = xarray.DataArray([1, 2, 3], [("x", [1, 1, 2])])
         da.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}}
         original_attrs = da.attrs
@@ -151,7 +151,7 @@ def test_groupby_attr_retention(self):
         result = da.groupby("x").sum()
         assert result.attrs == {}
 
-    def test_concat_attr_retention(self):
+    def test_concat_attr_retention(self) -> None:
         ds1 = create_test_dataset_attrs()
         ds2 = create_test_dataset_attrs()
         ds2.attrs = {"wrong": "attributes"}
@@ -164,7 +164,7 @@
         assert result.attrs == original_attrs
 
     @pytest.mark.xfail
-    def test_merge_attr_retention(self):
+    def test_merge_attr_retention(self) -> None:
         da1 = create_test_dataarray_attrs(var="var1")
         da2 = create_test_dataarray_attrs(var="var2")
         da2.attrs = {"wrong": "attributes"}
@@ -175,7 +175,7 @@ def test_merge_attr_retention(self):
         result = merge([da1, da2])
         assert result.attrs == original_attrs
 
-    def test_display_style_text(self):
+    def test_display_style_text(self) -> None:
         ds = create_test_dataset_attrs()
         with xarray.set_options(display_style="text"):
             text = ds._repr_html_()
@@ -183,21 +183,21 @@
         assert "'nested'" in text
         assert "<xarray.Dataset>" in text
 
-    def test_display_style_html(self):
+    def test_display_style_html(self) -> None:
         ds = create_test_dataset_attrs()
         with xarray.set_options(display_style="html"):
             html = ds._repr_html_()
             assert html.startswith("<div>")
             assert "'nested'" in html
 
-    def test_display_dataarray_style_text(self):
+    def test_display_dataarray_style_text(self) -> None:
         da = create_test_dataarray_attrs()
         with xarray.set_options(display_style="text"):
             text = da._repr_html_()
             assert text.startswith("<pre>")
                 assert "<xarray.DataArray 'var1'" in text
     
    -    def test_display_dataarray_style_html(self):
    +    def test_display_dataarray_style_html(self) -> None:
             da = create_test_dataarray_attrs()
             with xarray.set_options(display_style="html"):
                 html = da._repr_html_()
    diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
    index b7a5f9405d1..7f77a677d6d 100644
    --- a/xarray/tests/test_plugins.py
    +++ b/xarray/tests/test_plugins.py
    @@ -39,13 +39,13 @@ def dummy_duplicated_entrypoints():
     
     
     @pytest.mark.filterwarnings("ignore:Found")
    -def test_remove_duplicates(dummy_duplicated_entrypoints):
    +def test_remove_duplicates(dummy_duplicated_entrypoints) -> None:
         with pytest.warns(RuntimeWarning):
             entrypoints = plugins.remove_duplicates(dummy_duplicated_entrypoints)
         assert len(entrypoints) == 2
     
     
    -def test_broken_plugin():
    +def test_broken_plugin() -> None:
         broken_backend = pkg_resources.EntryPoint.parse(
             "broken_backend = xarray.tests.test_plugins:backend_1"
         )
    @@ -56,7 +56,7 @@ def test_broken_plugin():
         assert "Engine 'broken_backend'" in message
     
     
    -def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):
    +def test_remove_duplicates_warnings(dummy_duplicated_entrypoints) -> None:
     
         with pytest.warns(RuntimeWarning) as record:
             _ = plugins.remove_duplicates(dummy_duplicated_entrypoints)
    @@ -69,7 +69,7 @@ def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):
     
     
     @mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
    -def test_backends_dict_from_pkg():
    +def test_backends_dict_from_pkg() -> None:
         specs = [
             "engine1 = xarray.tests.test_plugins:backend_1",
             "engine2 = xarray.tests.test_plugins:backend_2",
    @@ -80,7 +80,7 @@ def test_backends_dict_from_pkg():
         assert engines.keys() == set(("engine1", "engine2"))
     
     
    -def test_set_missing_parameters():
    +def test_set_missing_parameters() -> None:
         backend_1 = DummyBackendEntrypoint1
         backend_2 = DummyBackendEntrypoint2
         backend_2.open_dataset_parameters = ("filename_or_obj",)
    @@ -96,28 +96,28 @@ def test_set_missing_parameters():
         plugins.set_missing_parameters({"engine": backend})
         assert backend.open_dataset_parameters == ("filename_or_obj", "decoder")
     
    -    backend = DummyBackendEntrypointArgs()
    -    backend.open_dataset_parameters = ("filename_or_obj", "decoder")
    -    plugins.set_missing_parameters({"engine": backend})
    -    assert backend.open_dataset_parameters == ("filename_or_obj", "decoder")
    +    backend_args = DummyBackendEntrypointArgs()
    +    backend_args.open_dataset_parameters = ("filename_or_obj", "decoder")
    +    plugins.set_missing_parameters({"engine": backend_args})
    +    assert backend_args.open_dataset_parameters == ("filename_or_obj", "decoder")
     
     
    -def test_set_missing_parameters_raise_error():
    +def test_set_missing_parameters_raise_error() -> None:
     
         backend = DummyBackendEntrypointKwargs()
         with pytest.raises(TypeError):
             plugins.set_missing_parameters({"engine": backend})
     
    -    backend = DummyBackendEntrypointArgs()
    +    backend_args = DummyBackendEntrypointArgs()
         with pytest.raises(TypeError):
    -        plugins.set_missing_parameters({"engine": backend})
    +        plugins.set_missing_parameters({"engine": backend_args})
     
     
     @mock.patch(
         "pkg_resources.EntryPoint.load",
         mock.MagicMock(return_value=DummyBackendEntrypoint1),
     )
    -def test_build_engines():
    +def test_build_engines() -> None:
         dummy_pkg_entrypoint = pkg_resources.EntryPoint.parse(
             "cfgrib = xarray.tests.test_plugins:backend_1"
         )
    @@ -134,7 +134,7 @@ def test_build_engines():
         "pkg_resources.EntryPoint.load",
         mock.MagicMock(return_value=DummyBackendEntrypoint1),
     )
    -def test_build_engines_sorted():
    +def test_build_engines_sorted() -> None:
         dummy_pkg_entrypoints = [
             pkg_resources.EntryPoint.parse(
                 "dummy2 = xarray.tests.test_plugins:backend_1",
    @@ -163,7 +163,7 @@ def test_build_engines_sorted():
         "xarray.backends.plugins.list_engines",
         mock.MagicMock(return_value={"dummy": DummyBackendEntrypointArgs()}),
     )
    -def test_no_matching_engine_found():
    +def test_no_matching_engine_found() -> None:
         with pytest.raises(ValueError, match=r"did not find a match in any"):
             plugins.guess_engine("not-valid")
     
    @@ -175,7 +175,7 @@ def test_no_matching_engine_found():
         "xarray.backends.plugins.list_engines",
         mock.MagicMock(return_value={}),
     )
    -def test_engines_not_installed():
    +def test_engines_not_installed() -> None:
         with pytest.raises(ValueError, match=r"xarray is unable to open"):
             plugins.guess_engine("not-valid")
     
    diff --git a/xarray/tests/test_print_versions.py b/xarray/tests/test_print_versions.py
    index 01c30e5e301..42ebe5b2ac2 100644
    --- a/xarray/tests/test_print_versions.py
    +++ b/xarray/tests/test_print_versions.py
    @@ -3,7 +3,7 @@
     import xarray
     
     
    -def test_show_versions():
    +def test_show_versions() -> None:
         f = io.StringIO()
         xarray.show_versions(file=f)
         assert "INSTALLED VERSIONS" in f.getvalue()
    diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py
    index dc1db4dc8d7..2bde7529d1e 100644
    --- a/xarray/tests/test_testing.py
    +++ b/xarray/tests/test_testing.py
    @@ -29,7 +29,7 @@ def quantity(x):
         has_pint = False
     
     
    -def test_allclose_regression():
    +def test_allclose_regression() -> None:
         x = xr.DataArray(1.01)
         y = xr.DataArray(1.02)
         xr.testing.assert_allclose(x, y, atol=0.01)
    @@ -53,7 +53,7 @@ def test_allclose_regression():
             ),
         ),
     )
    -def test_assert_allclose(obj1, obj2):
    +def test_assert_allclose(obj1, obj2) -> None:
         with pytest.raises(AssertionError):
             xr.testing.assert_allclose(obj1, obj2)
     
    @@ -83,7 +83,7 @@ def test_assert_allclose(obj1, obj2):
             pytest.param(0.0, [1e-17, 2], id="first scalar"),
         ),
     )
    -def test_assert_duckarray_equal_failing(duckarray, obj1, obj2):
    +def test_assert_duckarray_equal_failing(duckarray, obj1, obj2) -> None:
         # TODO: actually check the repr
         a = duckarray(obj1)
         b = duckarray(obj2)
    @@ -119,7 +119,7 @@ def test_assert_duckarray_equal_failing(duckarray, obj1, obj2):
             pytest.param(0.0, [0, 0], id="first scalar"),
         ),
     )
    -def test_assert_duckarray_equal(duckarray, obj1, obj2):
    +def test_assert_duckarray_equal(duckarray, obj1, obj2) -> None:
         a = duckarray(obj1)
         b = duckarray(obj2)
     
    @@ -136,7 +136,7 @@ def test_assert_duckarray_equal(duckarray, obj1, obj2):
             "assert_duckarray_allclose",
         ],
     )
    -def test_ensure_warnings_not_elevated(func):
    +def test_ensure_warnings_not_elevated(func) -> None:
         # make sure warnings are not elevated to errors in the assertion functions
         # e.g. by @pytest.mark.filterwarnings("error")
         # see https://github.com/pydata/xarray/pull/4760#issuecomment-774101639
    diff --git a/xarray/tests/test_tutorial.py b/xarray/tests/test_tutorial.py
    index 225fda08f68..411ad52368d 100644
    --- a/xarray/tests/test_tutorial.py
    +++ b/xarray/tests/test_tutorial.py
    @@ -11,13 +11,15 @@ class TestLoadDataset:
         def setUp(self):
             self.testfile = "tiny"
     
    -    def test_download_from_github(self, tmp_path):
    +    def test_download_from_github(self, tmp_path) -> None:
             cache_dir = tmp_path / tutorial._default_cache_dir_name
             ds = tutorial.open_dataset(self.testfile, cache_dir=cache_dir).load()
             tiny = DataArray(range(5), name="tiny").to_dataset()
             assert_identical(ds, tiny)
     
    -    def test_download_from_github_load_without_cache(self, tmp_path, monkeypatch):
    +    def test_download_from_github_load_without_cache(
    +        self, tmp_path, monkeypatch
    +    ) -> None:
             cache_dir = tmp_path / tutorial._default_cache_dir_name
     
             ds_nocache = tutorial.open_dataset(
    
    From 059c91abed0ae153b874fb764b4514dede9c455f Mon Sep 17 00:00:00 2001
    From: Deepak Cherian 
    Date: Mon, 23 Aug 2021 10:42:09 -0600
    Subject: [PATCH 18/20] Add .git-blame-ignore-revs (#5708)
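
Note: locally this file only takes effect once it is configured, e.g. with
`git config blame.ignoreRevsFile .git-blame-ignore-revs` (available since
Git 2.23).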
    
    ---
     .git-blame-ignore-revs | 5 +++++
     1 file changed, 5 insertions(+)
     create mode 100644 .git-blame-ignore-revs
    
    diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
    new file mode 100644
    index 00000000000..465572c9246
    --- /dev/null
    +++ b/.git-blame-ignore-revs
    @@ -0,0 +1,5 @@
    +# black PR 3142
    +d089df385e737f71067309ff7abae15994d581ec
    +
    +# isort PR 1924
    +0e73e240107caee3ffd1a1149f0150c390d43251
    
    From 808b2e9eec1ec48d43cf9311dee27cd5550b8fd2 Mon Sep 17 00:00:00 2001
    From: Giacomo Caria <44147817+gcaria@users.noreply.github.com>
    Date: Mon, 23 Aug 2021 19:00:39 +0200
    Subject: [PATCH 19/20] Set coord name concat when `concat`ing along a
     DataArray (#5611)
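
A minimal sketch of the new behaviour, mirroring the added test below
(the names `da` and `combined` are illustrative, not part of the patch):

    import xarray as xr

    da = xr.DataArray([0], dims="a")

    # Concatenating along an unnamed DataArray now names the new coordinate
    # after that DataArray's dimension ("b") instead of leaving it unnamed.
    combined = xr.concat([da, da], dim=xr.DataArray([0, 1], dims="b"))
    assert list(combined.coords) == ["b"]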
    
    ---
     xarray/core/concat.py       | 2 ++
     xarray/tests/test_concat.py | 9 +++++++++
     2 files changed, 11 insertions(+)
    
    diff --git a/xarray/core/concat.py b/xarray/core/concat.py
    index 7a15685fd56..e26c1464f2d 100644
    --- a/xarray/core/concat.py
    +++ b/xarray/core/concat.py
    @@ -264,6 +264,8 @@ def _calc_concat_dim_coord(dim):
             (dim,) = coord.dims
         else:
             coord = dim
    +        if coord.name is None:
    +            coord.name = dim.dims[0]
             (dim,) = coord.dims
         return dim, coord
     
    diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py
    index 36ef0237b27..e049f843bed 100644
    --- a/xarray/tests/test_concat.py
    +++ b/xarray/tests/test_concat.py
    @@ -678,6 +678,15 @@ def test_concat_str_dtype(self, dtype, dim):
     
             assert np.issubdtype(actual.x2.dtype, dtype)
     
    +    def test_concat_coord_name(self):
    +
    +        da = DataArray([0], dims="a")
    +        da_concat = concat([da, da], dim=DataArray([0, 1], dims="b"))
    +        assert list(da_concat.coords) == ["b"]
    +
    +        da_concat_std = concat([da, da], dim=DataArray([0, 1]))
    +        assert list(da_concat_std.coords) == ["dim_0"]
    +
     
     @pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {}))
     @pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {}))
    
    From a6b44d72010c5311fad4f60194b527194e26094b Mon Sep 17 00:00:00 2001
    From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
    Date: Mon, 23 Aug 2021 13:16:20 -0600
    Subject: [PATCH 20/20] Bump actions/github-script from 4.0.2 to 4.1 (#5730)
    
    Bumps [actions/github-script](https://github.com/actions/github-script) from 4.0.2 to 4.1.
    - [Release notes](https://github.com/actions/github-script/releases)
    - [Commits](https://github.com/actions/github-script/compare/v4.0.2...v4.1)
    
    ---
    updated-dependencies:
    - dependency-name: actions/github-script
      dependency-type: direct:production
      update-type: version-update:semver-minor
    ...
    
    Signed-off-by: dependabot[bot] 
    
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    ---
     .github/workflows/upstream-dev-ci.yaml | 2 +-
     1 file changed, 1 insertion(+), 1 deletion(-)
    
    diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml
    index 9b1664a0292..15ff3f7bda6 100644
    --- a/.github/workflows/upstream-dev-ci.yaml
    +++ b/.github/workflows/upstream-dev-ci.yaml
    @@ -122,7 +122,7 @@ jobs:
               shopt -s globstar
               python .github/workflows/parse_logs.py logs/**/*-log
           - name: Report failures
    -        uses: actions/github-script@v4.0.2
    +        uses: actions/github-script@v4.1
             with:
               github-token: ${{ secrets.GITHUB_TOKEN }}
               script: |