diff --git a/.binder/environment.yml b/.binder/environment.yml index 6fd5829c5e6..6caea42df87 100644 --- a/.binder/environment.yml +++ b/.binder/environment.yml @@ -2,7 +2,7 @@ name: xarray-examples channels: - conda-forge dependencies: - - python=3.8 + - python=3.9 - boto3 - bottleneck - cartopy @@ -26,6 +26,7 @@ dependencies: - pandas - pint - pip + - pooch - pydap - pynio - rasterio diff --git a/.github/workflows/cancel-duplicate-runs.yaml b/.github/workflows/cancel-duplicate-runs.yaml index 46637bdc112..9f74360b034 100644 --- a/.github/workflows/cancel-duplicate-runs.yaml +++ b/.github/workflows/cancel-duplicate-runs.yaml @@ -10,6 +10,6 @@ jobs: runs-on: ubuntu-latest if: github.repository == 'pydata/xarray' steps: - - uses: styfle/cancel-workflow-action@0.9.0 + - uses: styfle/cancel-workflow-action@0.9.1 with: workflow_id: ${{ github.event.workflow.id }} diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 2b9a6405f21..ed731b25f76 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -103,7 +103,7 @@ jobs: $PYTEST_EXTRA_FLAGS - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2.0.2 with: file: ./coverage.xml flags: unittests,${{ matrix.env }} diff --git a/.github/workflows/ci-pre-commit-autoupdate.yaml b/.github/workflows/ci-pre-commit-autoupdate.yaml index 8ba7ac14ef1..b10a541197e 100644 --- a/.github/workflows/ci-pre-commit-autoupdate.yaml +++ b/.github/workflows/ci-pre-commit-autoupdate.yaml @@ -35,7 +35,6 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} EXECUTE_COMMANDS: | python -m pre_commit autoupdate - python .github/workflows/sync_linter_versions.py .pre-commit-config.yaml ci/requirements/mypy_only python -m pre_commit run --all-files COMMIT_MESSAGE: 'pre-commit: autoupdate hook versions' COMMIT_NAME: 'github-actions[bot]' diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3918f92574d..22a05eb1fc0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -100,7 +100,7 @@ jobs: path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2.0.2 with: file: ./coverage.xml flags: unittests diff --git a/.github/workflows/sync_linter_versions.py b/.github/workflows/sync_linter_versions.py deleted file mode 100755 index cb0b1355c71..00000000000 --- a/.github/workflows/sync_linter_versions.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python -import argparse -import itertools -import pathlib -import re - -import yaml -from packaging import version -from packaging.requirements import Requirement - -operator_re = re.compile("=+") - - -def extract_versions(config): - repos = config.get("repos") - if repos is None: - raise ValueError("invalid pre-commit configuration") - - extracted_versions = ( - ((hook["id"], version.parse(repo["rev"])) for hook in repo["hooks"]) - for repo in repos - ) - return dict(itertools.chain.from_iterable(extracted_versions)) - - -def update_requirement(line, new_versions): - # convert to pep-508 compatible - preprocessed = operator_re.sub("==", line) - requirement = Requirement(preprocessed) - - specifier, *_ = requirement.specifier - old_version = specifier.version - new_version = new_versions.get(requirement.name, old_version) - - new_line = f"{requirement.name}={new_version}" - - return new_line - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--dry", action="store_true") - parser.add_argument( - metavar="pre-commit-config", dest="pre_commit_config", type=pathlib.Path - ) - parser.add_argument("requirements", type=pathlib.Path) - args = parser.parse_args() - - with args.pre_commit_config.open() as f: - config = yaml.safe_load(f) - - versions = extract_versions(config) - mypy_version = versions["mypy"] - - requirements_text = args.requirements.read_text() - requirements = requirements_text.split("\n") - new_requirements = [ - update_requirement(line, versions) - if line and not line.startswith("# ") - else line - for line in requirements - ] - new_requirements_text = "\n".join(new_requirements) - - if args.dry: - separator = "\n" + "—" * 80 + "\n" - print( - "contents of the old requirements file:", - requirements_text, - "contents of the new requirements file:", - new_requirements_text, - sep=separator, - end=separator, - ) - else: - args.requirements.write_text(new_requirements_text) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 42d1fb0a4a5..53525d0def9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,16 +8,16 @@ repos: - id: check-yaml # isort should run before black as black sometimes tweaks the isort output - repo: https://github.com/PyCQA/isort - rev: 5.9.1 + rev: 5.9.3 hooks: - id: isort # https://github.com/python/black#version-control-integration - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 21.7b0 hooks: - id: black - repo: https://github.com/keewis/blackdoc - rev: v0.3.3 + rev: v0.3.4 hooks: - id: blackdoc - repo: https://gitlab.com/pycqa/flake8 @@ -30,7 +30,6 @@ repos: # - id: velin # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - # version must correspond to the one in .github/workflows/ci-additional.yaml rev: v0.910 hooks: - id: mypy @@ -44,6 +43,7 @@ repos: types-pytz, # Dependencies that are typed numpy, + typing-extensions==3.10.0.0, ] # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 # - repo: https://github.com/asottile/pyupgrade diff --git a/asv_bench/benchmarks/repr.py b/asv_bench/benchmarks/repr.py index 617e9313fd1..405f6cd0530 100644 --- a/asv_bench/benchmarks/repr.py +++ b/asv_bench/benchmarks/repr.py @@ -1,8 +1,30 @@ +import numpy as np import pandas as pd import xarray as xr +class Repr: + def setup(self): + a = np.arange(0, 100) + data_vars = dict() + for i in a: + data_vars[f"long_variable_name_{i}"] = xr.DataArray( + name=f"long_variable_name_{i}", + data=np.arange(0, 20), + dims=[f"long_coord_name_{i}_x"], + coords={f"long_coord_name_{i}_x": np.arange(0, 20) * 2}, + ) + self.ds = xr.Dataset(data_vars) + self.ds.attrs = {f"attr_{k}": 2 for k in a} + + def time_repr(self): + repr(self.ds) + + def time_repr_html(self): + self.ds._repr_html_() + + class ReprMultiIndex: def setup(self): index = pd.MultiIndex.from_product( diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 073b28b8cfb..92a0f8fc7e7 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -7,6 +7,7 @@ conda uninstall -y --force \ matplotlib \ dask \ distributed \ + fsspec \ zarr \ cftime \ rasterio \ @@ -40,4 +41,5 @@ python -m pip install \ git+https://github.com/mapbox/rasterio \ git+https://github.com/hgrecco/pint \ git+https://github.com/pydata/bottleneck \ - git+https://github.com/pydata/sparse + git+https://github.com/pydata/sparse \ + git+https://github.com/intake/filesystem_spec diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index fc32d35837b..78ead40d5a2 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -10,6 +10,7 @@ dependencies: - cftime - dask - distributed + - fsspec!=2021.7.0 - h5netcdf - h5py - hdf5 diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index c8afc3c21bb..f64ca3677cc 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -12,6 +12,7 @@ dependencies: - cftime - dask - distributed + - fsspec!=2021.7.0 - h5netcdf - h5py - hdf5 diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 076b0eb452a..fc27d9c3fe8 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -54,7 +54,6 @@ core.rolling.DatasetCoarsen.var core.rolling.DatasetCoarsen.boundary core.rolling.DatasetCoarsen.coord_func - core.rolling.DatasetCoarsen.keep_attrs core.rolling.DatasetCoarsen.obj core.rolling.DatasetCoarsen.side core.rolling.DatasetCoarsen.trim_excess @@ -120,7 +119,6 @@ core.rolling.DatasetRolling.var core.rolling.DatasetRolling.center core.rolling.DatasetRolling.dim - core.rolling.DatasetRolling.keep_attrs core.rolling.DatasetRolling.min_periods core.rolling.DatasetRolling.obj core.rolling.DatasetRolling.rollings @@ -199,7 +197,6 @@ core.rolling.DataArrayCoarsen.var core.rolling.DataArrayCoarsen.boundary core.rolling.DataArrayCoarsen.coord_func - core.rolling.DataArrayCoarsen.keep_attrs core.rolling.DataArrayCoarsen.obj core.rolling.DataArrayCoarsen.side core.rolling.DataArrayCoarsen.trim_excess @@ -263,7 +260,6 @@ core.rolling.DataArrayRolling.var core.rolling.DataArrayRolling.center core.rolling.DataArrayRolling.dim - core.rolling.DataArrayRolling.keep_attrs core.rolling.DataArrayRolling.min_periods core.rolling.DataArrayRolling.obj core.rolling.DataArrayRolling.window diff --git a/doc/api.rst b/doc/api.rst index bb3a99bfbb0..fb2296d1226 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -24,7 +24,6 @@ Top-level functions combine_by_coords combine_nested where - set_options infer_freq full_like zeros_like @@ -686,6 +685,7 @@ Dataset methods open_zarr Dataset.to_netcdf Dataset.to_pandas + Dataset.as_numpy Dataset.to_zarr save_mfdataset Dataset.to_array @@ -716,6 +716,8 @@ DataArray methods DataArray.to_pandas DataArray.to_series DataArray.to_dataframe + DataArray.to_numpy + DataArray.as_numpy DataArray.to_index DataArray.to_masked_array DataArray.to_cdms2 diff --git a/doc/conf.py b/doc/conf.py index f6f7abd61b2..0a6d1504161 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -313,7 +313,7 @@ "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), "iris": ("https://scitools-iris.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), - "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), "numba": ("https://numba.pydata.org/numba-doc/latest", None), "matplotlib": ("https://matplotlib.org/stable/", None), "dask": ("https://docs.dask.org/en/latest", None), diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index f3d3c0f1902..506236f3b9a 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -8,7 +8,6 @@ Required dependencies - Python (3.7 or later) - setuptools (40.4 or later) -- typing-extensions (3.10 or later) - `numpy `__ (1.17 or later) - `pandas `__ (1.0 or later) @@ -96,7 +95,7 @@ dependencies: - **setuptools:** 42 months (but no older than 40.4) - **numpy:** 18 months (`NEP-29 `_) -- **dask and dask.distributed:** 12 months (but no older than 2.9) +- **dask and dask.distributed:** 12 months - **sparse, pint** and other libraries that rely on `NEP-18 `_ for integration: very latest available versions only, until the technology will have diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cd65e0dbe35..3dad685aaf7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,13 +14,64 @@ What's New np.random.seed(123456) -.. _whats-new.0.18.3: -v0.18.3 (unreleased) +.. _whats-new.0.19.1: + +v0.19.1 (unreleased) --------------------- New Features ~~~~~~~~~~~~ + + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Explicit indexes refactor: avoid ``len(index)`` in ``map_blocks`` (:pull:`5670`). + By `Deepak Cherian `_. +- Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`). + By `Benoit Bovy `_. +- Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) + By `Jimmy Westling `_. + +.. _whats-new.0.19.0: + +v0.19.0 (23 July 2021) +---------------------- + +This release brings improvements to plotting of categorical data, the ability to specify how attributes +are combined in xarray operations, a new high-level :py:func:`unify_chunks` function, as well as various +deprecations, bug fixes, and minor improvements. + + +Many thanks to the 29 contributors to this release!: + +Andrew Williams, Augustus, Aureliana Barghini, Benoit Bovy, crusaderky, Deepak Cherian, ellesmith88, +Elliott Sales de Andrade, Giacomo Caria, github-actions[bot], Illviljan, Joeperdefloep, joooeey, Julia Kent, +Julius Busecke, keewis, Mathias Hauser, Matthias Göbel, Mattia Almansi, Maximilian Roos, Peter Andreas Entschev, +Ray Bell, Sander, Santiago Soler, Sebastian, Spencer Clark, Stephan Hoyer, Thomas Hirtz, Thomas Nicholas. + +New Features +~~~~~~~~~~~~ +- Allow passing argument ``missing_dims`` to :py:meth:`Variable.transpose` and :py:meth:`Dataset.transpose` + (:issue:`5550`, :pull:`5586`) + By `Giacomo Caria `_. - Allow passing a dictionary as coords to a :py:class:`DataArray` (:issue:`5527`, reverts :pull:`1539`, which had deprecated this due to python's inconsistent ordering in earlier versions). By `Sander van Rijn `_. @@ -53,6 +104,12 @@ New Features - Allow removal of the coordinate attribute ``coordinates`` on variables by setting ``.attrs['coordinates']= None`` (:issue:`5510`). By `Elle Smith `_. +- Added ``**kwargs`` argument to :py:meth:`open_rasterio` to access overviews (:issue:`3269`). + By `Pushkar Kopparla `_. +- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`). + By `Tom Nicholas `_. +- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`). + By `Tom Nicholas `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -62,10 +119,18 @@ Breaking changes pre-existing array values. This is a safer default than the prior ``mode="a"``, and allows for higher performance writes (:pull:`5252`). By `Stephan Hoyer `_. +- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead + of `datasets` so anyone calling this method using a named parameter will need to update + the name accordingly (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. Deprecations ~~~~~~~~~~~~ +- Removed the deprecated ``dim`` kwarg to :py:func:`DataArray.integrate` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.rolling` (:pull:`5630`) +- Removed the deprecated ``keep_attrs`` kwarg to :py:func:`DataArray.coarsen` (:pull:`5630`) +- Completed deprecation of passing an ``xarray.DataArray`` to :py:func:`Variable` - will now raise a ``TypeError`` (:pull:`5630`) Bug fixes ~~~~~~~~~ @@ -85,10 +150,9 @@ Bug fixes - Plotting a pcolormesh with ``xscale="log"`` and/or ``yscale="log"`` works as expected after improving the way the interval breaks are generated (:issue:`5333`). By `Santiago Soler `_ - - -Documentation -~~~~~~~~~~~~~ +- :py:func:`combine_by_coords` can now handle combining a list of unnamed + ``DataArray`` as input (:issue:`3248`, :pull:`4696`). + By `Augustus Ijams `_. Internal Changes @@ -99,7 +163,6 @@ Internal Changes - Publish test results & timings on each PR. (:pull:`5537`) By `Maximilian Roos `_. - - Explicit indexes refactor: add a ``xarray.Index.query()`` method in which one may eventually provide a custom implementation of label-based data selection (not ready yet for public use). Also refactor the internal, @@ -144,22 +207,9 @@ New Features - Raise more informative error when decoding time variables with invalid reference dates. (:issue:`5199`, :pull:`5288`). By `Giacomo Caria `_. -Breaking changes -~~~~~~~~~~~~~~~~ -- The main parameter to :py:func:`combine_by_coords` is renamed to `data_objects` instead - of `datasets` so anyone calling this method using a named parameter will need to update - the name accordingly (:issue:`3248`, :pull:`4696`). - By `Augustus Ijams `_. - -Deprecations -~~~~~~~~~~~~ - Bug fixes ~~~~~~~~~ -- :py:func:`combine_by_coords` can now handle combining a list of unnamed - ``DataArray`` as input (:issue:`3248`, :pull:`4696`). - By `Augustus Ijams `_. - Opening netCDF files from a path that doesn't end in ``.nc`` without supplying an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in 0.18.0. diff --git a/setup.cfg b/setup.cfg index 5a6e0b3435d..c44d207bf0f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,7 @@ classifiers = Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering [options] @@ -78,7 +79,6 @@ install_requires = numpy >= 1.17 pandas >= 1.0 setuptools >= 40.4 # For pkg_resources - typing-extensions >= 3.10 # Backported type hints [options.extras_require] io = diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 49a5a9ec7ae..1891fac8668 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -162,7 +162,14 @@ def default(s): return parsed_meta -def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=None): +def open_rasterio( + filename, + parse_coordinates=None, + chunks=None, + cache=None, + lock=None, + **kwargs, +): """Open a file with rasterio (experimental). This should work with any file that rasterio can open (most often: @@ -272,7 +279,13 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc if lock is None: lock = RASTERIO_LOCK - manager = CachingFileManager(rasterio.open, filename, lock=lock, mode="r") + manager = CachingFileManager( + rasterio.open, + filename, + lock=lock, + mode="r", + kwargs=kwargs, + ) riods = manager.acquire() if vrt_params is not None: riods = WarpedVRT(riods, **vrt_params) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index d492e3dfb92..aec12d2b154 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -737,6 +737,7 @@ def open_zarr( See Also -------- open_dataset + open_mfdataset References ---------- diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 8f2ba2f4b97..a53ac094253 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -18,7 +18,7 @@ import pandas as pd from . import dtypes -from .indexes import Index, PandasIndex, get_indexer_nd, wrap_pandas_index +from .indexes import Index, PandasIndex, get_indexer_nd from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str, safe_cast_to_index from .variable import IndexVariable, Variable @@ -53,7 +53,10 @@ def _get_joiner(join, index_cls): def _override_indexes(objects, all_indexes, exclude): for dim, dim_indexes in all_indexes.items(): if dim not in exclude: - lengths = {index.size for index in dim_indexes} + lengths = { + getattr(index, "size", index.to_pandas_index().size) + for index in dim_indexes + } if len(lengths) != 1: raise ValueError( f"Indexes along dimension {dim!r} don't have the same length." @@ -300,16 +303,14 @@ def align( joined_indexes = {} for dim, matching_indexes in all_indexes.items(): if dim in indexes: - # TODO: benbovy - flexible indexes. maybe move this logic in util func - if isinstance(indexes[dim], Index): - index = indexes[dim] - else: - index = PandasIndex(safe_cast_to_index(indexes[dim])) + index, _ = PandasIndex.from_pandas_index( + safe_cast_to_index(indexes[dim]), dim + ) if ( any(not index.equals(other) for other in matching_indexes) or dim in unlabeled_dim_sizes ): - joined_indexes[dim] = index + joined_indexes[dim] = indexes[dim] else: if ( any( @@ -323,17 +324,18 @@ def align( joiner = _get_joiner(join, type(matching_indexes[0])) index = joiner(matching_indexes) # make sure str coords are not cast to object - index = maybe_coerce_to_str(index, all_coords[dim]) + index = maybe_coerce_to_str(index.to_pandas_index(), all_coords[dim]) joined_indexes[dim] = index else: index = all_coords[dim][0] if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] - # TODO: benbovy - flexible indexes: expose a size property for xarray.Index? - # Some indexes may not have a defined size (e.g., built from multiple coords of - # different sizes) - labeled_size = index.size + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + if isinstance(index, PandasIndex): + labeled_size = index.to_pandas_index().size + else: + labeled_size = index.size if len(unlabeled_sizes | {labeled_size}) > 1: raise ValueError( f"arguments without labels along dimension {dim!r} cannot be " @@ -350,7 +352,14 @@ def align( result = [] for obj in objects: - valid_indexers = {k: v for k, v in joined_indexes.items() if k in obj.dims} + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + valid_indexers = {} + for k, index in joined_indexes.items(): + if k in obj.dims: + if isinstance(index, Index): + valid_indexers[k] = index.to_pandas_index() + else: + valid_indexers[k] = index if not valid_indexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) @@ -471,7 +480,11 @@ def reindex_like_indexers( ValueError If any dimensions without labels have different sizes. """ - indexers = {k: v for k, v in other.xindexes.items() if k in target.dims} + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5647 + # this doesn't support yet indexes other than pd.Index + indexers = { + k: v.to_pandas_index() for k, v in other.xindexes.items() if k in target.dims + } for dim in other.dims: if dim not in indexers and dim in target.dims: @@ -560,7 +573,8 @@ def reindex_variables( "from that to be indexed along {:s}".format(str(indexer.dims), dim) ) - target = new_indexes[dim] = wrap_pandas_index(safe_cast_to_index(indexers[dim])) + target = safe_cast_to_index(indexers[dim]) + new_indexes[dim] = PandasIndex(target, dim) if dim in indexes: # TODO (benbovy - flexible indexes): support other indexes than pd.Index? diff --git a/xarray/core/combine.py b/xarray/core/combine.py index de6d16ef5c3..7e1565e50de 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -77,9 +77,8 @@ def _infer_concat_order_from_coords(datasets): "inferring concatenation order" ) - # TODO (benbovy, flexible indexes): all indexes should be Pandas.Index - # get pd.Index objects from Index objects - indexes = [index.array for index in indexes] + # TODO (benbovy, flexible indexes): support flexible indexes? + indexes = [index.to_pandas_index() for index in indexes] # If dimension coordinate values are same on every dataset then # should be leaving this dimension alone (it's just a "bystander") @@ -635,7 +634,7 @@ def _combine_single_variable_hypercube( return concatenated -# TODO remove empty list default param after version 0.19, see PR4696 +# TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( data_objects=[], compat="no_conflicts", @@ -849,11 +848,11 @@ def combine_by_coords( precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 """ - # TODO remove after version 0.19, see PR4696 + # TODO remove after version 0.21, see PR4696 if datasets is not None: warnings.warn( "The datasets argument has been renamed to `data_objects`." - " In future passing a value for datasets will raise an error." + " From 0.21 on passing a value for datasets will raise an error." ) data_objects = datasets diff --git a/xarray/core/common.py b/xarray/core/common.py index 7b6e9198b43..ab822f576d3 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -821,7 +821,6 @@ def rolling( dim: Mapping[Hashable, int] = None, min_periods: int = None, center: Union[bool, Mapping[Hashable, bool]] = False, - keep_attrs: bool = None, **window_kwargs: int, ): """ @@ -889,9 +888,7 @@ def rolling( """ dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") - return self._rolling_cls( - self, dim, min_periods=min_periods, center=center, keep_attrs=keep_attrs - ) + return self._rolling_cls(self, dim, min_periods=min_periods, center=center) def rolling_exp( self, @@ -940,7 +937,6 @@ def coarsen( boundary: str = "exact", side: Union[str, Mapping[Hashable, str]] = "left", coord_func: str = "mean", - keep_attrs: bool = None, **window_kwargs: int, ): """ @@ -1009,7 +1005,6 @@ def coarsen( boundary=boundary, side=side, coord_func=coord_func, - keep_attrs=keep_attrs, ) def resample( diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b4d553c235a..900af885319 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -51,13 +51,7 @@ ) from .dataset import Dataset, split_indexes from .formatting import format_item -from .indexes import ( - Index, - Indexes, - default_indexes, - propagate_indexes, - wrap_pandas_index, -) +from .indexes import Index, Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords from .options import OPTIONS, _get_keep_attrs @@ -426,12 +420,12 @@ def __init__( self._close = None def _replace( - self, + self: T_DataArray, variable: Variable = None, coords=None, name: Union[Hashable, None, Default] = _default, indexes=None, - ) -> "DataArray": + ) -> T_DataArray: if variable is None: variable = self.variable if coords is None: @@ -473,15 +467,14 @@ def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": return self coords = self._coords.copy() for name, idx in indexes.items(): - coords[name] = IndexVariable(name, idx) + coords[name] = IndexVariable(name, idx.to_pandas_index()) obj = self._replace(coords=coords) # switch from dimension to level names, if necessary dim_names: Dict[Any, str] = {} for dim, idx in indexes.items(): - # TODO: benbovy - flexible indexes: update when MultiIndex has its own class - pd_idx = idx.array - if not isinstance(pd_idx, pd.MultiIndex) and pd_idx.name != dim: + pd_idx = idx.to_pandas_index() + if not isinstance(idx, pd.MultiIndex) and pd_idx.name != dim: dim_names[dim] = idx.name if dim_names: obj = obj.rename(dim_names) @@ -623,7 +616,16 @@ def __len__(self) -> int: @property def data(self) -> Any: - """The array's data as a dask or numpy array""" + """ + The DataArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + See Also + -------- + DataArray.to_numpy + DataArray.as_numpy + DataArray.values + """ return self.variable.data @data.setter @@ -632,13 +634,46 @@ def data(self, value: Any) -> None: @property def values(self) -> np.ndarray: - """The array's data as a numpy.ndarray""" + """ + The array's data as a numpy.ndarray. + + If the array's data is not a numpy.ndarray this will attempt to convert + it naively using np.array(), which will raise an error if the array + type does not support coercion like this (e.g. cupy). + """ return self.variable.values @values.setter def values(self, value: Any) -> None: self.variable.values = value + def to_numpy(self) -> np.ndarray: + """ + Coerces wrapped data to numpy and returns a numpy.ndarray. + + See also + -------- + DataArray.as_numpy : Same but returns the surrounding DataArray instead. + Dataset.as_numpy + DataArray.values + DataArray.data + """ + return self.variable.to_numpy() + + def as_numpy(self: T_DataArray) -> T_DataArray: + """ + Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. + + See also + -------- + DataArray.to_numpy : Same but returns only the data as a numpy.ndarray object. + Dataset.as_numpy : Converts all variables in a Dataset. + DataArray.values + DataArray.data + """ + coords = {k: v.as_numpy() for k, v in self._coords.items()} + return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + @property def _in_memory(self) -> bool: return self.variable._in_memory @@ -931,7 +966,7 @@ def persist(self, **kwargs) -> "DataArray": ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self, deep: bool = True, data: Any = None) -> "DataArray": + def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -1004,12 +1039,7 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": if self._indexes is None: indexes = self._indexes else: - # TODO: benbovy: flexible indexes: support all xarray indexes (not just pandas.Index) - # xarray Index needs a copy method. - indexes = { - k: wrap_pandas_index(v.to_pandas_index().copy(deep=deep)) - for k, v in self._indexes.items() - } + indexes = {k: v.copy(deep=deep) for k, v in self._indexes.items()} return self._replace(variable, coords, indexes=indexes) def __copy__(self) -> "DataArray": @@ -2742,7 +2772,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: result : MaskedArray Masked where invalid values (nan or inf) occur. """ - values = self.values # only compute lazy arrays once + values = self.to_numpy() # only compute lazy arrays once isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) @@ -3540,8 +3570,6 @@ def integrate( self, coord: Union[Hashable, Sequence[Hashable]] = None, datetime_unit: str = None, - *, - dim: Union[Hashable, Sequence[Hashable]] = None, ) -> "DataArray": """Integrate along the given coordinate using the trapezoidal rule. @@ -3553,8 +3581,6 @@ def integrate( ---------- coord : hashable, or sequence of hashable Coordinate(s) used for the integration. - dim : hashable, or sequence of hashable - Coordinate(s) used for the integration. datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \ 'ps', 'fs', 'as'}, optional Specify the unit if a datetime coordinate is used. @@ -3591,21 +3617,6 @@ def integrate( array([5.4, 6.6, 7.8]) Dimensions without coordinates: y """ - if dim is not None and coord is not None: - raise ValueError( - "Cannot pass both 'dim' and 'coord'. Please pass only 'coord' instead." - ) - - if dim is not None and coord is None: - coord = dim - msg = ( - "The `dim` keyword argument to `DataArray.integrate` is " - "being replaced with `coord`, for consistency with " - "`Dataset.integrate`. Please pass `coord` instead." - " `dim` will be removed in version 0.19.0." - ) - warnings.warn(msg, FutureWarning, stacklevel=2) - ds = self._to_temp_dataset().integrate(coord, datetime_unit) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 13da8cfad03..533ecadbae5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -71,7 +71,6 @@ propagate_indexes, remove_unused_levels_categories, roll_index, - wrap_pandas_index, ) from .indexing import is_fancy_indexer from .merge import ( @@ -1184,7 +1183,7 @@ def _overwrite_indexes(self, indexes: Mapping[Any, Index]) -> "Dataset": variables = self._variables.copy() new_indexes = dict(self.xindexes) for name, idx in indexes.items(): - variables[name] = IndexVariable(name, idx) + variables[name] = IndexVariable(name, idx.to_pandas_index()) new_indexes[name] = idx obj = self._replace(variables, indexes=new_indexes) @@ -1323,6 +1322,18 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": return self._replace(variables, attrs=attrs) + def as_numpy(self: "Dataset") -> "Dataset": + """ + Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. + + See also + -------- + DataArray.as_numpy + DataArray.to_numpy : Returns only the data as a numpy.ndarray object. + """ + numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} + return self._replace(variables=numpy_variables) + @property def _level_coords(self) -> Dict[str, Hashable]: """Return a mapping of all MultiIndex levels and their corresponding @@ -2462,6 +2473,10 @@ def sel( pos_indexers, new_indexes = remap_label_indexers( self, indexers=indexers, method=method, tolerance=tolerance ) + # TODO: benbovy - flexible indexes: also use variables returned by Index.query + # (temporary dirty fix). + new_indexes = {k: v[0] for k, v in new_indexes.items()} + result = self.isel(indexers=pos_indexers, drop=drop) return result._overwrite_indexes(new_indexes) @@ -3285,20 +3300,21 @@ def _rename_dims(self, name_dict): return {name_dict.get(k, k): v for k, v in self.dims.items()} def _rename_indexes(self, name_dict, dims_set): + # TODO: benbovy - flexible indexes: https://github.com/pydata/xarray/issues/5645 if self._indexes is None: return None indexes = {} - for k, v in self.xindexes.items(): - # TODO: benbovy - flexible indexes: make it compatible with any xarray Index - index = v.to_pandas_index() + for k, v in self.indexes.items(): new_name = name_dict.get(k, k) if new_name not in dims_set: continue - if isinstance(index, pd.MultiIndex): - new_names = [name_dict.get(k, k) for k in index.names] - indexes[new_name] = PandasMultiIndex(index.rename(names=new_names)) + if isinstance(v, pd.MultiIndex): + new_names = [name_dict.get(k, k) for k in v.names] + indexes[new_name] = PandasMultiIndex( + v.rename(names=new_names), new_name + ) else: - indexes[new_name] = PandasIndex(index.rename(new_name)) + indexes[new_name] = PandasIndex(v.rename(new_name), new_name) return indexes def _rename_all(self, name_dict, dims_dict): @@ -3527,7 +3543,10 @@ def swap_dims( if new_index.nlevels == 1: # make sure index name matches dimension name new_index = new_index.rename(k) - indexes[k] = wrap_pandas_index(new_index) + if isinstance(new_index, pd.MultiIndex): + indexes[k] = PandasMultiIndex(new_index, k) + else: + indexes[k] = PandasIndex(new_index, k) else: var = v.to_base_variable() var.dims = dims @@ -3800,7 +3819,7 @@ def reorder_levels( raise ValueError(f"coordinate {dim} has no MultiIndex") new_index = index.reorder_levels(order) variables[dim] = IndexVariable(coord.dims, new_index) - indexes[dim] = PandasMultiIndex(new_index) + indexes[dim] = PandasMultiIndex(new_index, dim) return self._replace(variables, indexes=indexes) @@ -3828,7 +3847,7 @@ def _stack_once(self, dims, new_dim): coord_names = set(self._coord_names) - set(dims) | {new_dim} indexes = {k: v for k, v in self.xindexes.items() if k not in dims} - indexes[new_dim] = wrap_pandas_index(idx) + indexes[new_dim] = PandasMultiIndex(idx, new_dim) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes @@ -4017,8 +4036,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": variables[name] = var for name, lev in zip(index.names, index.levels): - variables[name] = IndexVariable(name, lev) - indexes[name] = PandasIndex(lev) + idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + variables[name] = idx_vars[name] + indexes[name] = idx coord_names = set(self._coord_names) - {dim} | set(index.names) @@ -4056,8 +4076,9 @@ def _unstack_full_reindex( variables[name] = var for name, lev in zip(new_dim_names, index.levels): - variables[name] = IndexVariable(name, lev) - indexes[name] = PandasIndex(lev) + idx, idx_vars = PandasIndex.from_pandas_index(lev, name) + variables[name] = idx_vars[name] + indexes[name] = idx coord_names = set(self._coord_names) - {dim} | set(new_dim_names) @@ -4161,6 +4182,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset": """Update this dataset's variables with those from another dataset. Just like :py:meth:`dict.update` this is a in-place operation. + For a non-inplace version, see :py:meth:`Dataset.merge`. Parameters ---------- @@ -4179,7 +4201,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset": Updated dataset. Note that since the update is in-place this is the input dataset. - It is deprecated since version 0.17 and scheduled to be removed in 0.19. + It is deprecated since version 0.17 and scheduled to be removed in 0.21. Raises ------ @@ -4190,6 +4212,7 @@ def update(self, other: "CoercibleMapping") -> "Dataset": See Also -------- Dataset.assign + Dataset.merge """ merge_result = dataset_update_method(self, other) return self._replace(inplace=True, **merge_result._asdict()) @@ -4263,6 +4286,10 @@ def merge( ------ MergeError If any variables conflict (see ``compat``). + + See Also + -------- + Dataset.update """ other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( @@ -4543,7 +4570,11 @@ def drop_dims( drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims} return self.drop_vars(drop_vars) - def transpose(self, *dims: Hashable) -> "Dataset": + def transpose( + self, + *dims: Hashable, + missing_dims: str = "raise", + ) -> "Dataset": """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -4554,6 +4585,12 @@ def transpose(self, *dims: Hashable) -> "Dataset": *dims : hashable, optional By default, reverse the dimensions on each array. Otherwise, reorder the dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Dataset: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions Returns ------- @@ -4572,12 +4609,10 @@ def transpose(self, *dims: Hashable) -> "Dataset": numpy.transpose DataArray.transpose """ - if dims: - if set(dims) ^ set(self.dims) and ... not in dims: - raise ValueError( - f"arguments to transpose ({dims}) must be " - f"permuted dataset dimensions ({tuple(self.dims)})" - ) + # Use infix_dims to check once for missing dimensions + if len(dims) != 0: + _ = list(infix_dims(dims, self.dims, missing_dims)) + ds = self.copy() for name, var in self._variables.items(): var_dims = tuple(dim for dim in dims if dim in (var.dims + (...,))) @@ -5813,10 +5848,13 @@ def diff(self, dim, n=1, label="upper"): indexes = dict(self.xindexes) if dim in indexes: - # TODO: benbovy - flexible indexes: check slicing of xarray indexes? - # or only allow this for pandas indexes? - index = indexes[dim].to_pandas_index() - indexes[dim] = PandasIndex(index[kwargs_new[dim]]) + if isinstance(indexes[dim], PandasIndex): + # maybe optimize? (pandas index already indexed above with var.isel) + new_index = indexes[dim].index[kwargs_new[dim]] + if isinstance(new_index, pd.MultiIndex): + indexes[dim] = PandasMultiIndex(new_index, dim) + else: + indexes[dim] = PandasIndex(new_index, dim) difference = self._replace_with_new_dims(variables, indexes=indexes) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index ab30facf49f..70d1a61f56c 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -385,12 +385,14 @@ def _mapping_repr( elif max_rows is not None and len_mapping > max_rows: summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] first_rows = max_rows // 2 + max_rows % 2 - items = list(mapping.items()) - summary += [summarizer(k, v, col_width) for k, v in items[:first_rows]] + keys = list(mapping.keys()) + summary += [summarizer(k, mapping[k], col_width) for k in keys[:first_rows]] if max_rows > 1: last_rows = max_rows // 2 summary += [pretty_print(" ...", col_width) + " ..."] - summary += [summarizer(k, v, col_width) for k, v in items[-last_rows:]] + summary += [ + summarizer(k, mapping[k], col_width) for k in keys[-last_rows:] + ] else: summary += [summarizer(k, v, col_width) for k, v in mapping.items()] else: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 90d8eec6623..429c37af588 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -1,6 +1,4 @@ import collections.abc -from contextlib import suppress -from datetime import timedelta from typing import ( TYPE_CHECKING, Any, @@ -18,28 +16,26 @@ import pandas as pd from . import formatting, utils -from .indexing import ExplicitlyIndexedNDArrayMixin, NumpyIndexingAdapter -from .npcompat import DTypeLike +from .indexing import ( + LazilyIndexedArray, + PandasIndexingAdapter, + PandasMultiIndexingAdapter, +) from .utils import is_dict_like, is_scalar if TYPE_CHECKING: - from .variable import Variable + from .variable import IndexVariable, Variable + +IndexVars = Dict[Hashable, "IndexVariable"] class Index: """Base class inherited by all xarray-compatible indexes.""" - __slots__ = ("coord_names",) - - def __init__(self, coord_names: Union[Hashable, Iterable[Hashable]]): - if isinstance(coord_names, Hashable): - coord_names = (coord_names,) - self.coord_names = tuple(coord_names) - @classmethod def from_variables( - cls, variables: Dict[Hashable, "Variable"], **kwargs - ): # pragma: no cover + cls, variables: Mapping[Hashable, "Variable"] + ) -> Tuple["Index", Optional[IndexVars]]: # pragma: no cover raise NotImplementedError() def to_pandas_index(self) -> pd.Index: @@ -52,8 +48,10 @@ def to_pandas_index(self) -> pd.Index: """ raise TypeError(f"{type(self)} cannot be cast to a pandas.Index object.") - def query(self, labels: Dict[Hashable, Any]): # pragma: no cover - raise NotImplementedError + def query( + self, labels: Dict[Hashable, Any] + ) -> Tuple[Any, Optional[Tuple["Index", IndexVars]]]: # pragma: no cover + raise NotImplementedError() def equals(self, other): # pragma: no cover raise NotImplementedError() @@ -64,6 +62,13 @@ def union(self, other): # pragma: no cover def intersection(self, other): # pragma: no cover raise NotImplementedError() + def copy(self, deep: bool = True): # pragma: no cover + raise NotImplementedError() + + def __getitem__(self, indexer: Any): + # if not implemented, index will be dropped from the Dataset or DataArray + raise NotImplementedError() + def _sanitize_slice_element(x): from .dataarray import DataArray @@ -138,64 +143,68 @@ def get_indexer_nd(index, labels, method=None, tolerance=None): return indexer -class PandasIndex(Index, ExplicitlyIndexedNDArrayMixin): - """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" +class PandasIndex(Index): + """Wrap a pandas.Index as an xarray compatible index.""" - __slots__ = ("array", "_dtype") + __slots__ = ("index", "dim") - def __init__( - self, array: Any, dtype: DTypeLike = None, coord_name: Optional[Hashable] = None - ): - if coord_name is None: - coord_name = tuple() - super().__init__(coord_name) + def __init__(self, array: Any, dim: Hashable): + self.index = utils.safe_cast_to_index(array) + self.dim = dim - self.array = utils.safe_cast_to_index(array) + @classmethod + def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + from .variable import IndexVariable - if dtype is None: - if isinstance(array, pd.PeriodIndex): - dtype_ = np.dtype("O") - elif hasattr(array, "categories"): - # category isn't a real numpy dtype - dtype_ = array.categories.dtype - elif not utils.is_valid_numpy_dtype(array.dtype): - dtype_ = np.dtype("O") - else: - dtype_ = array.dtype + if len(variables) != 1: + raise ValueError( + f"PandasIndex only accepts one variable, found {len(variables)} variables" + ) + + name, var = next(iter(variables.items())) + + if var.ndim != 1: + raise ValueError( + "PandasIndex only accepts a 1-dimensional variable, " + f"variable {name!r} has {var.ndim} dimensions" + ) + + dim = var.dims[0] + + obj = cls(var.data, dim) + + data = PandasIndexingAdapter(obj.index) + index_var = IndexVariable( + dim, data, attrs=var.attrs, encoding=var.encoding, fastpath=True + ) + + return obj, {name: index_var} + + @classmethod + def from_pandas_index(cls, index: pd.Index, dim: Hashable): + from .variable import IndexVariable + + if index.name is None: + name = dim + index = index.copy() + index.name = dim else: - dtype_ = np.dtype(dtype) # type: ignore[assignment] - self._dtype = dtype_ + name = index.name + + data = PandasIndexingAdapter(index) + index_var = IndexVariable(dim, data, fastpath=True) + + return cls(index, dim), {name: index_var} def to_pandas_index(self) -> pd.Index: - return self.array - - @property - def dtype(self) -> np.dtype: - return self._dtype - - def __array__(self, dtype: DTypeLike = None) -> np.ndarray: - if dtype is None: - dtype = self.dtype - array = self.array - if isinstance(array, pd.PeriodIndex): - with suppress(AttributeError): - # this might not be public API - array = array.astype("object") - return np.asarray(array.values, dtype=dtype) - - @property - def shape(self) -> Tuple[int]: - return (len(self.array),) + return self.index - def query( - self, labels, method=None, tolerance=None - ) -> Tuple[Any, Union["PandasIndex", None]]: + def query(self, labels, method=None, tolerance=None): assert len(labels) == 1 coord_name, label = next(iter(labels.items())) - index = self.array if isinstance(label, slice): - indexer = _query_slice(index, label, coord_name, method, tolerance) + indexer = _query_slice(self.index, label, coord_name, method, tolerance) elif is_dict_like(label): raise ValueError( "cannot use a dict-like object for selection on " @@ -210,7 +219,7 @@ def query( if label.ndim == 0: # see https://github.com/pydata/xarray/pull/4292 for details label_value = label[()] if label.dtype.kind in "mM" else label.item() - if isinstance(index, pd.CategoricalIndex): + if isinstance(self.index, pd.CategoricalIndex): if method is not None: raise ValueError( "'method' is not a valid kwarg when indexing using a CategoricalIndex." @@ -219,115 +228,114 @@ def query( raise ValueError( "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." ) - indexer = index.get_loc(label_value) + indexer = self.index.get_loc(label_value) else: - indexer = index.get_loc( + indexer = self.index.get_loc( label_value, method=method, tolerance=tolerance ) elif label.dtype.kind == "b": indexer = label else: - indexer = get_indexer_nd(index, label, method, tolerance) + indexer = get_indexer_nd(self.index, label, method, tolerance) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") return indexer, None def equals(self, other): - if isinstance(other, pd.Index): - other = type(self)(other) - return self.array.equals(other.array) + return self.index.equals(other.index) def union(self, other): - if isinstance(other, pd.Index): - other = type(self)(other) - return type(self)(self.array.union(other.array)) + new_index = self.index.union(other.index) + return type(self)(new_index, self.dim) def intersection(self, other): - if isinstance(other, pd.Index): - other = PandasIndex(other) - return type(self)(self.array.intersection(other.array)) - - def __getitem__( - self, indexer - ) -> Union[ - "PandasIndex", - NumpyIndexingAdapter, - np.ndarray, - np.datetime64, - np.timedelta64, - ]: - key = indexer.tuple - if isinstance(key, tuple) and len(key) == 1: - # unpack key so it can index a pandas.Index object (pandas.Index - # objects don't like tuples) - (key,) = key - - if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional - return NumpyIndexingAdapter(self.array.values)[indexer] - - result = self.array[key] - - if isinstance(result, pd.Index): - result = type(self)(result, dtype=self.dtype) - else: - # result is a scalar - if result is pd.NaT: - # work around the impossibility of casting NaT with asarray - # note: it probably would be better in general to return - # pd.Timestamp rather np.than datetime64 but this is easier - # (for now) - result = np.datetime64("NaT", "ns") - elif isinstance(result, timedelta): - result = np.timedelta64(getattr(result, "value", result), "ns") - elif isinstance(result, pd.Timestamp): - # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 - # numpy fails to convert pd.Timestamp to np.datetime64[ns] - result = np.asarray(result.to_datetime64()) - elif self.dtype != object: - result = np.asarray(result, dtype=self.dtype) - - # as for numpy.ndarray indexing, we always want the result to be - # a NumPy array. - result = utils.to_0d_array(result) - - return result - - def transpose(self, order) -> pd.Index: - return self.array # self.array should be always one-dimensional - - def __repr__(self) -> str: - return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" - - def copy(self, deep: bool = True) -> "PandasIndex": - # Not the same as just writing `self.array.copy(deep=deep)`, as - # shallow copies of the underlying numpy.ndarrays become deep ones - # upon pickling - # >>> len(pickle.dumps((self.array, self.array))) - # 4000281 - # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) - # 8000341 - array = self.array.copy(deep=True) if deep else self.array - return type(self)(array, self._dtype) + new_index = self.index.intersection(other.index) + return type(self)(new_index, self.dim) + + def copy(self, deep=True): + return type(self)(self.index.copy(deep=deep), self.dim) + + def __getitem__(self, indexer: Any): + return type(self)(self.index[indexer], self.dim) + + +def _create_variables_from_multiindex(index, dim, level_meta=None): + from .variable import IndexVariable + + if level_meta is None: + level_meta = {} + + variables = {} + + dim_coord_adapter = PandasMultiIndexingAdapter(index) + variables[dim] = IndexVariable( + dim, LazilyIndexedArray(dim_coord_adapter), fastpath=True + ) + + for level in index.names: + meta = level_meta.get(level, {}) + data = PandasMultiIndexingAdapter( + index, dtype=meta.get("dtype"), level=level, adapter=dim_coord_adapter + ) + variables[level] = IndexVariable( + dim, + data, + attrs=meta.get("attrs"), + encoding=meta.get("encoding"), + fastpath=True, + ) + + return variables class PandasMultiIndex(PandasIndex): - def query( - self, labels, method=None, tolerance=None - ) -> Tuple[Any, Union["PandasIndex", None]]: + @classmethod + def from_variables(cls, variables: Mapping[Hashable, "Variable"]): + if any([var.ndim != 1 for var in variables.values()]): + raise ValueError("PandasMultiIndex only accepts 1-dimensional variables") + + dims = set([var.dims for var in variables.values()]) + if len(dims) != 1: + raise ValueError( + "unmatched dimensions for variables " + + ",".join([str(k) for k in variables]) + ) + + dim = next(iter(dims))[0] + index = pd.MultiIndex.from_arrays( + [var.values for var in variables.values()], names=variables.keys() + ) + obj = cls(index, dim) + + level_meta = { + name: {"dtype": var.dtype, "attrs": var.attrs, "encoding": var.encoding} + for name, var in variables.items() + } + index_vars = _create_variables_from_multiindex( + index, dim, level_meta=level_meta + ) + + return obj, index_vars + + @classmethod + def from_pandas_index(cls, index: pd.MultiIndex, dim: Hashable): + index_vars = _create_variables_from_multiindex(index, dim) + return cls(index, dim), index_vars + + def query(self, labels, method=None, tolerance=None): if method is not None or tolerance is not None: raise ValueError( "multi-index does not support ``method`` and ``tolerance``" ) - index = self.array new_index = None # label(s) given for multi-index level(s) - if all([lbl in index.names for lbl in labels]): + if all([lbl in self.index.names for lbl in labels]): is_nested_vals = _is_nested_tuple(tuple(labels.values())) - if len(labels) == index.nlevels and not is_nested_vals: - indexer = index.get_loc(tuple(labels[k] for k in index.names)) + if len(labels) == self.index.nlevels and not is_nested_vals: + indexer = self.index.get_loc(tuple(labels[k] for k in self.index.names)) else: for k, v in labels.items(): # index should be an item (i.e. Hashable) not an array-like @@ -336,7 +344,7 @@ def query( "Vectorized selection is not " f"available along coordinate {k!r} (multi-index level)" ) - indexer, new_index = index.get_loc_level( + indexer, new_index = self.index.get_loc_level( tuple(labels.values()), level=tuple(labels.keys()) ) # GH2619. Raise a KeyError if nothing is chosen @@ -346,16 +354,18 @@ def query( # assume one label value given for the multi-index "array" (dimension) else: if len(labels) > 1: - coord_name = next(iter(set(labels) - set(index.names))) + coord_name = next(iter(set(labels) - set(self.index.names))) raise ValueError( f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) " - f"and one or more coordinates among {index.names!r} (multi-index levels)" + f"and one or more coordinates among {self.index.names!r} (multi-index levels)" ) coord_name, label = next(iter(labels.items())) if is_dict_like(label): - invalid_levels = [name for name in label if name not in index.names] + invalid_levels = [ + name for name in label if name not in self.index.names + ] if invalid_levels: raise ValueError( f"invalid multi-index level names {invalid_levels}" @@ -363,15 +373,15 @@ def query( return self.query(label) elif isinstance(label, slice): - indexer = _query_slice(index, label, coord_name) + indexer = _query_slice(self.index, label, coord_name) elif isinstance(label, tuple): if _is_nested_tuple(label): - indexer = index.get_locs(label) - elif len(label) == index.nlevels: - indexer = index.get_loc(label) + indexer = self.index.get_locs(label) + elif len(label) == self.index.nlevels: + indexer = self.index.get_loc(label) else: - indexer, new_index = index.get_loc_level( + indexer, new_index = self.index.get_loc_level( label, level=list(range(len(label))) ) @@ -382,7 +392,7 @@ def query( else _asarray_tuplesafe(label) ) if label.ndim == 0: - indexer, new_index = index.get_loc_level(label.item(), level=0) + indexer, new_index = self.index.get_loc_level(label.item(), level=0) elif label.dtype.kind == "b": indexer = label else: @@ -391,21 +401,20 @@ def query( "Vectorized selection is not available along " f"coordinate {coord_name!r} with a multi-index" ) - indexer = get_indexer_nd(index, label) + indexer = get_indexer_nd(self.index, label) if np.any(indexer < 0): raise KeyError(f"not all values found in index {coord_name!r}") if new_index is not None: - new_index = PandasIndex(new_index) - - return indexer, new_index - - -def wrap_pandas_index(index): - if isinstance(index, pd.MultiIndex): - return PandasMultiIndex(index) - else: - return PandasIndex(index) + if isinstance(new_index, pd.MultiIndex): + new_index, new_vars = PandasMultiIndex.from_pandas_index( + new_index, self.dim + ) + else: + new_index, new_vars = PandasIndex.from_pandas_index(new_index, self.dim) + return indexer, (new_index, new_vars) + else: + return indexer, None def remove_unused_levels_categories(index: pd.Index) -> pd.Index: @@ -492,7 +501,13 @@ def isel_variable_and_index( index: Index, indexers: Mapping[Hashable, Union[int, slice, np.ndarray, "Variable"]], ) -> Tuple["Variable", Optional[Index]]: - """Index a Variable and pandas.Index together.""" + """Index a Variable and an Index together. + + If the index cannot be indexed, return None (it will be dropped). + + (note: not compatible yet with xarray flexible indexes). + + """ from .variable import Variable if not indexers: @@ -515,8 +530,11 @@ def isel_variable_and_index( indexer = indexers[dim] if isinstance(indexer, Variable): indexer = indexer.data - pd_index = index.to_pandas_index() - new_index = wrap_pandas_index(pd_index[indexer]) + try: + new_index = index[indexer] + except NotImplementedError: + new_index = None + return new_variable, new_index @@ -528,7 +546,7 @@ def roll_index(index: PandasIndex, count: int, axis: int = 0) -> PandasIndex: new_idx = pd_index[-count:].append(pd_index[:-count]) else: new_idx = pd_index[:] - return PandasIndex(new_idx) + return PandasIndex(new_idx, index.dim) def propagate_indexes( diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 1ace4db241d..70994a36ac8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -2,12 +2,15 @@ import functools import operator from collections import defaultdict -from typing import Any, Callable, Iterable, List, Tuple, Union +from contextlib import suppress +from datetime import timedelta +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union import numpy as np import pandas as pd from . import duck_array_ops, nputils, utils +from .npcompat import DTypeLike from .pycompat import ( dask_array_type, dask_version, @@ -569,9 +572,7 @@ def as_indexable(array): if isinstance(array, np.ndarray): return NumpyIndexingAdapter(array) if isinstance(array, pd.Index): - from .indexes import PandasIndex - - return PandasIndex(array) + return PandasIndexingAdapter(array) if isinstance(array, dask_array_type): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): @@ -1259,3 +1260,149 @@ def __setitem__(self, key, value): def transpose(self, order): return self.array.transpose(order) + + +class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" + + __slots__ = ("array", "_dtype") + + def __init__(self, array: pd.Index, dtype: DTypeLike = None): + self.array = utils.safe_cast_to_index(array) + + if dtype is None: + if isinstance(array, pd.PeriodIndex): + dtype_ = np.dtype("O") + elif hasattr(array, "categories"): + # category isn't a real numpy dtype + dtype_ = array.categories.dtype + elif not utils.is_valid_numpy_dtype(array.dtype): + dtype_ = np.dtype("O") + else: + dtype_ = array.dtype + else: + dtype_ = np.dtype(dtype) # type: ignore[assignment] + self._dtype = dtype_ + + @property + def dtype(self) -> np.dtype: + return self._dtype + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if dtype is None: + dtype = self.dtype + array = self.array + if isinstance(array, pd.PeriodIndex): + with suppress(AttributeError): + # this might not be public API + array = array.astype("object") + return np.asarray(array.values, dtype=dtype) + + @property + def shape(self) -> Tuple[int]: + return (len(self.array),) + + def __getitem__( + self, indexer + ) -> Union[ + "PandasIndexingAdapter", + NumpyIndexingAdapter, + np.ndarray, + np.datetime64, + np.timedelta64, + ]: + key = indexer.tuple + if isinstance(key, tuple) and len(key) == 1: + # unpack key so it can index a pandas.Index object (pandas.Index + # objects don't like tuples) + (key,) = key + + if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional + return NumpyIndexingAdapter(self.array.values)[indexer] + + result = self.array[key] + + if isinstance(result, pd.Index): + result = type(self)(result, dtype=self.dtype) + else: + # result is a scalar + if result is pd.NaT: + # work around the impossibility of casting NaT with asarray + # note: it probably would be better in general to return + # pd.Timestamp rather np.than datetime64 but this is easier + # (for now) + result = np.datetime64("NaT", "ns") + elif isinstance(result, timedelta): + result = np.timedelta64(getattr(result, "value", result), "ns") + elif isinstance(result, pd.Timestamp): + # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668 + # numpy fails to convert pd.Timestamp to np.datetime64[ns] + result = np.asarray(result.to_datetime64()) + elif self.dtype != object: + result = np.asarray(result, dtype=self.dtype) + + # as for numpy.ndarray indexing, we always want the result to be + # a NumPy array. + result = utils.to_0d_array(result) + + return result + + def transpose(self, order) -> pd.Index: + return self.array # self.array should be always one-dimensional + + def __repr__(self) -> str: + return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" + + def copy(self, deep: bool = True) -> "PandasIndexingAdapter": + # Not the same as just writing `self.array.copy(deep=deep)`, as + # shallow copies of the underlying numpy.ndarrays become deep ones + # upon pickling + # >>> len(pickle.dumps((self.array, self.array))) + # 4000281 + # >>> len(pickle.dumps((self.array, self.array.copy(deep=False)))) + # 8000341 + array = self.array.copy(deep=True) if deep else self.array + return type(self)(array, self._dtype) + + +class PandasMultiIndexingAdapter(PandasIndexingAdapter): + """Handles explicit indexing for a pandas.MultiIndex. + + This allows creating one instance for each multi-index level while + preserving indexing efficiency (memoized + might reuse another instance with + the same multi-index). + + """ + + __slots__ = ("array", "_dtype", "level", "adapter") + + def __init__( + self, + array: pd.MultiIndex, + dtype: DTypeLike = None, + level: Optional[str] = None, + adapter: Optional[PandasIndexingAdapter] = None, + ): + super().__init__(array, dtype) + self.level = level + self.adapter = adapter + + def __array__(self, dtype: DTypeLike = None) -> np.ndarray: + if self.level is not None: + return self.array.get_level_values(self.level).values + else: + return super().__array__(dtype) + + @functools.lru_cache(1) + def __getitem__(self, indexer): + if self.adapter is None: + return super().__getitem__(indexer) + else: + return self.adapter.__getitem__(indexer) + + def __repr__(self) -> str: + if self.level is None: + return super().__repr__() + else: + props = "(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})" + return f"{type(self).__name__}{props}" diff --git a/xarray/core/merge.py b/xarray/core/merge.py index db5b95fd415..b8b32bdaa01 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -578,7 +578,7 @@ def merge_core( combine_attrs: Optional[str] = "override", priority_arg: Optional[int] = None, explicit_coords: Optional[Sequence] = None, - indexes: Optional[Mapping[Hashable, Index]] = None, + indexes: Optional[Mapping[Hashable, Any]] = None, fill_value: object = dtypes.NA, ) -> _MergeResult: """Core logic for merging labeled objects. @@ -601,7 +601,8 @@ def merge_core( explicit_coords : set, optional An explicit list of variables from `objects` that are coordinates. indexes : dict, optional - Dictionary with values given by pandas.Index objects. + Dictionary with values given by xarray.Index objects or anything that + may be cast to pandas.Index objects. fill_value : scalar, optional Value to use for newly missing values @@ -979,8 +980,14 @@ def dataset_update_method( other[key] = value.drop_vars(coord_names) # use ds.coords and not ds.indexes, else str coords are cast to object - # TODO: benbovy - flexible indexes: fix this (it only works with pandas indexes) - indexes = {key: PandasIndex(dataset.coords[key]) for key in dataset.xindexes.keys()} + # TODO: benbovy - flexible indexes: make it work with any xarray index + indexes = {} + for key, index in dataset.xindexes.items(): + if isinstance(index, PandasIndex): + indexes[key] = dataset.coords[key] + else: + indexes[key] = index + return merge_core( [dataset, other], priority_arg=1, diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 795d30af28f..2c7f4249b5e 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -27,8 +27,6 @@ import numpy as np -from xarray.core.indexes import PandasIndex - from .alignment import align from .dataarray import DataArray from .dataset import Dataset @@ -295,9 +293,10 @@ def _wrapper( # check that index lengths and values are as expected for name, index in result.xindexes.items(): if name in expected["shapes"]: - if len(index) != expected["shapes"][name]: + if result.sizes[name] != expected["shapes"][name]: raise ValueError( - f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}." + f"Received dimension {name!r} of length {result.sizes[name]}. " + f"Expected length {expected['shapes'][name]}." ) if name in expected["indexes"]: expected_index = expected["indexes"][name] @@ -503,16 +502,10 @@ def subset_dataset_to_block( } expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - # TODO: benbovy - flexible indexes: clean this up - # for now assumes pandas index (thus can be indexed) but it won't be the case for - # all indexes - expected_indexes = {} - for dim in indexes: - idx = indexes[dim].to_pandas_index()[ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - expected_indexes[dim] = PandasIndex(idx) - expected["indexes"] = expected_indexes + expected["indexes"] = { + dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim in indexes + } from_wrapper = (gname,) + chunk_tuple graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) @@ -557,7 +550,13 @@ def subset_dataset_to_block( }, ) - result = Dataset(coords=indexes, attrs=template.attrs) + # TODO: benbovy - flexible indexes: make it work with custom indexes + # this will need to pass both indexes and coords to the Dataset constructor + result = Dataset( + coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, + attrs=template.attrs, + ) + for index in result.xindexes: result[index].attrs = template[index].attrs result[index].encoding = template[index].encoding @@ -568,8 +567,8 @@ def subset_dataset_to_block( for dim in dims: if dim in output_chunks: var_chunks.append(output_chunks[dim]) - elif dim in indexes: - var_chunks.append((len(indexes[dim]),)) + elif dim in result.xindexes: + var_chunks.append((result.sizes[dim],)) elif dim in template.dims: # new unindexed dimension var_chunks.append((template.sizes[dim],)) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 9f47da6c8cc..d1649235006 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,4 +1,5 @@ from distutils.version import LooseVersion +from importlib import import_module import numpy as np @@ -6,42 +7,57 @@ integer_types = (int, np.integer) -try: - import dask - import dask.array - from dask.base import is_dask_collection - dask_version = LooseVersion(dask.__version__) +class DuckArrayModule: + """ + Solely for internal isinstance and version checks. - # solely for isinstance checks - dask_array_type = (dask.array.Array,) + Motivated by having to only import pint when required (as pint currently imports xarray) + https://github.com/pydata/xarray/pull/5561#discussion_r664815718 + """ - def is_duck_dask_array(x): - return is_duck_array(x) and is_dask_collection(x) + def __init__(self, mod): + try: + duck_array_module = import_module(mod) + duck_array_version = LooseVersion(duck_array_module.__version__) + + if mod == "dask": + duck_array_type = (import_module("dask.array").Array,) + elif mod == "pint": + duck_array_type = (duck_array_module.Quantity,) + elif mod == "cupy": + duck_array_type = (duck_array_module.ndarray,) + elif mod == "sparse": + duck_array_type = (duck_array_module.SparseArray,) + else: + raise NotImplementedError + + except ImportError: # pragma: no cover + duck_array_module = None + duck_array_version = LooseVersion("0.0.0") + duck_array_type = () + self.module = duck_array_module + self.version = duck_array_version + self.type = duck_array_type + self.available = duck_array_module is not None -except ImportError: # pragma: no cover - dask_version = LooseVersion("0.0.0") - dask_array_type = () - is_duck_dask_array = lambda _: False - is_dask_collection = lambda _: False -try: - # solely for isinstance checks - import sparse +def is_duck_dask_array(x): + if DuckArrayModule("dask").available: + from dask.base import is_dask_collection + + return is_duck_array(x) and is_dask_collection(x) + else: + return False + - sparse_version = LooseVersion(sparse.__version__) - sparse_array_type = (sparse.SparseArray,) -except ImportError: # pragma: no cover - sparse_version = LooseVersion("0.0.0") - sparse_array_type = () +dsk = DuckArrayModule("dask") +dask_version = dsk.version +dask_array_type = dsk.type -try: - # solely for isinstance checks - import cupy +sp = DuckArrayModule("sparse") +sparse_array_type = sp.type +sparse_version = sp.version - cupy_version = LooseVersion(cupy.__version__) - cupy_array_type = (cupy.ndarray,) -except ImportError: # pragma: no cover - cupy_version = LooseVersion("0.0.0") - cupy_array_type = () +cupy_array_type = DuckArrayModule("cupy").type diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index b87dcda24b0..04052510f5d 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -48,10 +48,10 @@ class Rolling: xarray.DataArray.rolling """ - __slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs") - _attributes = ("window", "min_periods", "center", "dim", "keep_attrs") + __slots__ = ("obj", "window", "min_periods", "center", "dim") + _attributes = ("window", "min_periods", "center", "dim") - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object. @@ -89,15 +89,6 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None self.min_periods = np.prod(self.window) if min_periods is None else min_periods - if keep_attrs is not None: - warnings.warn( - "Passing ``keep_attrs`` to ``rolling`` is deprecated and will raise an" - " error in xarray 0.18. Please pass ``keep_attrs`` directly to the" - " applied function. Note that keep_attrs is now True per default.", - FutureWarning, - ) - self.keep_attrs = keep_attrs - def __repr__(self): """provide a nice str repr of our rolling object""" @@ -188,15 +179,8 @@ def _mapping_to_list( ) def _get_keep_attrs(self, keep_attrs): - if keep_attrs is None: - # TODO: uncomment the next line and remove the others after the deprecation - # keep_attrs = _get_keep_attrs(default=True) - - if self.keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - else: - keep_attrs = self.keep_attrs + keep_attrs = _get_keep_attrs(default=True) return keep_attrs @@ -204,7 +188,7 @@ def _get_keep_attrs(self, keep_attrs): class DataArrayRolling(Rolling): __slots__ = ("window_labels",) - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for DataArray. You should use DataArray.rolling() method to construct this object @@ -235,9 +219,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None xarray.Dataset.rolling xarray.Dataset.groupby """ - super().__init__( - obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs - ) + super().__init__(obj, windows, min_periods=min_periods, center=center) # TODO legacy attribute self.window_labels = self.obj[self.dim[0]] @@ -561,7 +543,7 @@ def _numpy_or_bottleneck_reduce( class DatasetRolling(Rolling): __slots__ = ("rollings",) - def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None): + def __init__(self, obj, windows, min_periods=None, center=False): """ Moving window object for Dataset. You should use Dataset.rolling() method to construct this object @@ -592,7 +574,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None xarray.Dataset.groupby xarray.DataArray.groupby """ - super().__init__(obj, windows, min_periods, center, keep_attrs) + super().__init__(obj, windows, min_periods, center) if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) # Keep each Rolling object as a dictionary @@ -768,11 +750,10 @@ class Coarsen(CoarsenArithmetic): "windows", "side", "trim_excess", - "keep_attrs", ) _attributes = ("windows", "side", "trim_excess") - def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): + def __init__(self, obj, windows, boundary, side, coord_func): """ Moving window object. @@ -799,17 +780,6 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): self.side = side self.boundary = boundary - if keep_attrs is not None: - warnings.warn( - "Passing ``keep_attrs`` to ``coarsen`` is deprecated and will raise an" - " error in xarray 0.19. Please pass ``keep_attrs`` directly to the" - " applied function, i.e. use ``ds.coarsen(...).mean(keep_attrs=False)``" - " instead of ``ds.coarsen(..., keep_attrs=False).mean()``" - " Note that keep_attrs is now True per default.", - FutureWarning, - ) - self.keep_attrs = keep_attrs - absent_dims = [dim for dim in windows.keys() if dim not in self.obj.dims] if absent_dims: raise ValueError( @@ -823,15 +793,8 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): self.coord_func = coord_func def _get_keep_attrs(self, keep_attrs): - if keep_attrs is None: - # TODO: uncomment the next line and remove the others after the deprecation - # keep_attrs = _get_keep_attrs(default=True) - - if self.keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) - else: - keep_attrs = self.keep_attrs + keep_attrs = _get_keep_attrs(default=True) return keep_attrs diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 72e34932579..a139d2ef10a 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -10,6 +10,7 @@ import warnings from enum import Enum from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -32,12 +33,6 @@ import numpy as np import pandas as pd -if sys.version_info >= (3, 10): - from typing import TypeGuard -else: - from typing_extensions import TypeGuard - - K = TypeVar("K") V = TypeVar("V") T = TypeVar("T") @@ -297,11 +292,7 @@ def either_dict_or_kwargs( return pos_kwargs -def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]: - """Whether to treat a value as a scalar. - - Any non-iterable, string, or 0-D array - """ +def _is_scalar(value, include_0d): from .variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES if include_0d: @@ -316,6 +307,37 @@ def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]: ) +# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without +# requiring typing_extensions as a required dependency to _run_ the code (it is required +# to type-check). +try: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard +except ImportError: + if TYPE_CHECKING: + raise + else: + + def is_scalar(value: Any, include_0d: bool = True) -> bool: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + +else: + + def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]: + """Whether to treat a value as a scalar. + + Any non-iterable, string, or 0-D array + """ + return _is_scalar(value, include_0d) + + def is_valid_numpy_dtype(dtype: Any) -> bool: try: np.dtype(dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index ace09c6f482..bd89fe97494 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -25,14 +25,22 @@ from . import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from .arithmetic import VariableArithmetic from .common import AbstractArray -from .indexes import PandasIndex, wrap_pandas_index -from .indexing import BasicIndexer, OuterIndexer, VectorizedIndexer, as_indexable +from .indexes import PandasIndex, PandasMultiIndex +from .indexing import ( + BasicIndexer, + OuterIndexer, + PandasIndexingAdapter, + VectorizedIndexer, + as_indexable, +) from .options import _get_keep_attrs from .pycompat import ( + DuckArrayModule, cupy_array_type, dask_array_type, integer_types, is_duck_dask_array, + sparse_array_type, ) from .utils import ( NdimSizeLenMixin, @@ -116,14 +124,9 @@ def as_variable(obj, name=None) -> "Union[Variable, IndexVariable]": obj = obj.copy(deep=False) elif isinstance(obj, tuple): if isinstance(obj[1], DataArray): - # TODO: change into TypeError - warnings.warn( - ( - "Using a DataArray object to construct a variable is" - " ambiguous, please extract the data using the .data property." - " This will raise a TypeError in 0.19.0." - ), - DeprecationWarning, + raise TypeError( + "Using a DataArray object to construct a variable is" + " ambiguous, please extract the data using the .data property." ) try: obj = Variable(*obj) @@ -173,11 +176,11 @@ def _maybe_wrap_data(data): Put pandas.Index and numpy.ndarray arguments in adapter objects to ensure they can be indexed properly. - NumpyArrayAdapter, PandasIndex and LazilyIndexedArray should + NumpyArrayAdapter, PandasIndexingAdapter and LazilyIndexedArray should all pass through unmodified. """ if isinstance(data, pd.Index): - return wrap_pandas_index(data) + return PandasIndexingAdapter(data) return data @@ -259,7 +262,7 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - data = data.get() if isinstance(data, cupy_array_type) else np.asarray(data) + data = np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") @@ -334,7 +337,9 @@ def nbytes(self): @property def _in_memory(self): - return isinstance(self._data, (np.ndarray, np.number, PandasIndex)) or ( + return isinstance( + self._data, (np.ndarray, np.number, PandasIndexingAdapter) + ) or ( isinstance(self._data, indexing.MemoryCachedArray) and isinstance(self._data.array, indexing.NumpyIndexingAdapter) ) @@ -542,7 +547,14 @@ def to_index_variable(self): def _to_xindex(self): # temporary function used internally as a replacement of to_index() # returns an xarray Index instance instead of a pd.Index instance - return wrap_pandas_index(self.to_index()) + index_var = self.to_index_variable() + index = index_var.to_index() + dim = index_var.dims[0] + + if isinstance(index, pd.MultiIndex): + return PandasMultiIndex(index, dim) + else: + return PandasIndex(index, dim) def to_index(self): """Convert this variable to a pandas.Index""" @@ -1069,6 +1081,30 @@ def chunk(self, chunks={}, name=None, lock=False): return self._replace(data=data) + def to_numpy(self) -> np.ndarray: + """Coerces wrapped data to numpy and returns a numpy.ndarray""" + # TODO an entrypoint so array libraries can choose coercion method? + data = self.data + + # TODO first attempt to call .to_numpy() once some libraries implement it + if isinstance(data, dask_array_type): + data = data.compute() + if isinstance(data, cupy_array_type): + data = data.get() + # pint has to be imported dynamically as pint imports xarray + pint_array_type = DuckArrayModule("pint").type + if isinstance(data, pint_array_type): + data = data.magnitude + if isinstance(data, sparse_array_type): + data = data.todense() + data = np.asarray(data) + + return data + + def as_numpy(self: VariableType) -> VariableType: + """Coerces wrapped data into a numpy array, returning a Variable.""" + return self._replace(data=self.to_numpy()) + def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): """ use sparse-array as backend. @@ -1378,7 +1414,11 @@ def roll(self, shifts=None, **shifts_kwargs): result = result._roll_one_dim(dim, count) return result - def transpose(self, *dims) -> "Variable": + def transpose( + self, + *dims, + missing_dims: str = "raise", + ) -> "Variable": """Return a new Variable object with transposed dimensions. Parameters @@ -1386,6 +1426,12 @@ def transpose(self, *dims) -> "Variable": *dims : str, optional By default, reverse the dimensions. Otherwise, reorder the dimensions to this order. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + Variable: + - "raise": raise an exception + - "warn": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions Returns ------- @@ -1404,7 +1450,9 @@ def transpose(self, *dims) -> "Variable": """ if len(dims) == 0: dims = self.dims[::-1] - dims = tuple(infix_dims(dims, self.dims)) + else: + dims = tuple(infix_dims(dims, self.dims, missing_dims)) + if len(dims) < 2 or dims == self.dims: # no need to transpose if only one dimension # or dims are in same order @@ -2538,8 +2586,8 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): raise ValueError(f"{type(self).__name__} objects must be 1-dimensional") # Unlike in Variable, always eagerly load values into memory - if not isinstance(self._data, PandasIndex): - self._data = PandasIndex(self._data) + if not isinstance(self._data, PandasIndexingAdapter): + self._data = PandasIndexingAdapter(self._data) def __dask_tokenize__(self): from dask.base import normalize_token @@ -2874,7 +2922,7 @@ def assert_unique_multiindex_level_names(variables): level_names = defaultdict(list) all_level_names = set() for var_name, var in variables.items(): - if isinstance(var._data, PandasIndex): + if isinstance(var._data, PandasIndexingAdapter): idx_level_names = var.to_index_variable().level_names if idx_level_names is not None: for n in idx_level_names: diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index da7d523d28f..e20b6568e79 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -7,17 +7,22 @@ Dataset.plot._____ """ import functools +from distutils.version import LooseVersion import numpy as np import pandas as pd +from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _adjust_legend_subtitles, _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, + _is_numeric, + _legend_add_subtitle, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, @@ -26,8 +31,132 @@ get_axis, import_matplotlib_pyplot, label_from_attrs, + legend_elements, ) +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) + + +def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): + def _determine_array(darray, name, array_style): + """Find and determine what type of array it is.""" + array = darray[name] + array_is_numeric = _is_numeric(array.values) + + if array_style is None: + array_style = "continuous" if array_is_numeric else "discrete" + elif array_style not in ["discrete", "continuous"]: + raise ValueError( + f"The style '{array_style}' is not valid, " + "valid options are None, 'discrete' or 'continuous'." + ) + + array_label = label_from_attrs(array) + + return array, array_style, array_label + + # Add nice looking labels: + out = dict(ylabel=label_from_attrs(darray)) + out.update( + { + k: label_from_attrs(darray[v]) if v in darray.coords else None + for k, v in [("xlabel", x), ("zlabel", z)] + } + ) + + # Add styles and labels for the dataarrays: + for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: + tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" + if a: + out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) + else: + out[tp], out[stl], out[lbl] = None, None, None + + return out + + +# copied from seaborn +def _parse_size(data, norm, width): + """ + Determine what type of data it is. Then normalize it to width. + + If the data is categorical, normalize it to numbers. + """ + plt = import_matplotlib_pyplot() + + if data is None: + return None + + data = data.values.ravel() + + if not _is_numeric(data): + # Data is categorical. + # Use pd.unique instead of np.unique because that keeps + # the order of the labels: + levels = pd.unique(data) + numbers = np.arange(1, 1 + len(levels)) + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = width + # width_range = min_width, max_width + + if norm is None: + norm = plt.Normalize() + elif isinstance(norm, tuple): + norm = plt.Normalize(*norm) + elif not isinstance(norm, plt.Normalize): + err = "``size_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +def _infer_scatter_data( + darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) +): + # Broadcast together all the chosen variables: + to_broadcast = dict(y=darray) + to_broadcast.update( + {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} + ) + to_broadcast.update( + {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} + ) + broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) + + # Normalize hue and size and create lookup tables: + for type_, mapping, norm, width in [ + ("hue", None, None, [0, 1]), + ("size", size_mapping, size_norm, size_range), + ]: + broadcasted_type = broadcasted.get(type_, None) + if broadcasted_type is not None: + if mapping is None: + mapping = _parse_size(broadcasted_type, norm, width) + + broadcasted[type_] = broadcasted_type.copy( + data=np.reshape( + mapping.loc[broadcasted_type.values.ravel()].values, + broadcasted_type.shape, + ) + ) + broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) + + return broadcasted + def _infer_line_data(darray, x, y, hue): @@ -301,7 +430,7 @@ def line( # Remove pd.Intervals if contained in xplt.values and/or yplt.values. xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, kwargs + xplt.to_numpy(), yplt.to_numpy(), kwargs ) xlabel = label_from_attrs(xplt, extra=x_suffix) ylabel = label_from_attrs(yplt, extra=y_suffix) @@ -320,7 +449,7 @@ def line( ax.set_title(darray._title_for_slice()) if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label) + ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks @@ -422,7 +551,7 @@ def hist( """ ax = get_axis(figsize, size, aspect, ax) - no_nan = np.ravel(darray.values) + no_nan = np.ravel(darray.to_numpy()) no_nan = no_nan[pd.notnull(no_nan)] primitive = ax.hist(no_nan, **kwargs) @@ -435,6 +564,291 @@ def hist( return primitive +def scatter( + darray, + *args, + row=None, + col=None, + figsize=None, + aspect=None, + size=None, + ax=None, + hue=None, + hue_style=None, + x=None, + z=None, + xincrease=None, + yincrease=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + add_legend=None, + add_colorbar=None, + cbar_kwargs=None, + cbar_ax=None, + vmin=None, + vmax=None, + norm=None, + infer_intervals=None, + center=None, + levels=None, + robust=None, + colors=None, + extend=None, + cmap=None, + _labels=True, + **kwargs, +): + """ + Scatter plot a DataArray along some coordinates. + + Parameters + ---------- + darray : DataArray + Dataarray to plot. + x, y : str + Variable names for x, y axis. + hue: str, optional + Variable by which to color scattered points + hue_style: str, optional + Can be either 'discrete' (legend) or 'continuous' (color bar). + markersize: str, optional + scatter only. Variable by which to vary size of scattered points. + size_norm: optional + Either None or 'Norm' instance to normalize the 'markersize' variable. + add_guide: bool, optional + Add a guide that depends on hue_style + - for "discrete", build a legend. + This is the default for non-numeric `hue` variables. + - for "continuous", build a colorbar + row : str, optional + If passed, make row faceted plots on this dimension name + col : str, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. + vmin, vmax : float, optional + Values to anchor the colormap, otherwise they are inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting one of these values will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : str or colormap, optional + The mapping from data values to color space. Either a + matplotlib colormap name or object. If not provided, this will + be either ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging + dataset). When `Seaborn` is installed, ``cmap`` may also be a + `seaborn` color palette. If ``cmap`` is seaborn color palette + and the plot type is not ``contour`` or ``contourf``, ``levels`` + must also be specified. + colors : color-like or list of color-like, optional + A single color or a list of colors. If the plot type is not ``contour`` + or ``contourf``, the ``levels`` argument is required. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If True and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {"neither", "both", "min", "max"}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, extend is inferred from vmin, vmax and the data limits. + levels : int or list-like object, optional + Split the colormap (cmap) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional keyword arguments to matplotlib + """ + plt = import_matplotlib_pyplot() + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + subplot_kws = dict(projection="3d") if z is not None else None + return _easy_facetgrid( + darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs + ) + + # Further + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + if _is_facetgrid: + # Why do I need to pop these here? + kwargs.pop("y", None) + kwargs.pop("args", None) + kwargs.pop("add_labels", None) + + _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) + size_norm = kwargs.pop("size_norm", None) + size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid + cmap_params = kwargs.pop("cmap_params", None) + + figsize = kwargs.pop("figsize", None) + subplot_kws = dict() + if z is not None and ax is None: + # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa + + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) + + add_guide = kwargs.pop("add_guide", None) + if add_legend is not None: + pass + elif add_guide is None or add_guide is True: + add_legend = True if _data["hue_style"] == "discrete" else False + elif add_legend is None: + add_legend = False + + if add_colorbar is not None: + pass + elif add_guide is None or add_guide is True: + add_colorbar = True if _data["hue_style"] == "continuous" else False + else: + add_colorbar = False + + # need to infer size_mapping with full dataset + _data.update( + _infer_scatter_data( + darray, + x, + z, + hue, + _sizes, + size_norm, + size_mapping, + _MARKERSIZE_RANGE, + ) + ) + + cmap_params_subset = {} + if _data["hue"] is not None: + kwargs.update(c=_data["hue"].values.ravel()) + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + scatter, _data["hue"].values, **locals() + ) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = { + vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] + } + + if _data["size"] is not None: + kwargs.update(s=_data["size"].values.ravel()) + + if LooseVersion(plt.matplotlib.__version__) < "3.5.0": + # Plot the data. 3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + + primitive = ax.scatter( + *[ + _data[v].values.ravel() + for v in axis_order + if _data.get(v, None) is not None + ], + **cmap_params_subset, + **kwargs, + ) + + # Set x, y, z labels: + i = 0 + set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] + for v in axis_order: + if _data.get(f"{v}label", None) is not None: + set_label[i](_data[f"{v}label"]) + i += 1 + + if add_legend: + + def to_label(data, key, x): + """Map prop values back to its original values.""" + if key in data: + # Use reindex to be less sensitive to float errors. + # Return as numpy array since legend_elements + # seems to require that: + return data[key].reindex(x, method="nearest").to_numpy() + else: + return x + + handles, labels = [], [] + for subtitle, prop, func in [ + ( + _data["hue_label"], + "colors", + functools.partial(to_label, _data, "hue_to_label"), + ), + ( + _data["size_label"], + "sizes", + functools.partial(to_label, _data, "size_to_label"), + ), + ]: + if subtitle: + # Get legend handles and labels that displays the + # values correctly. Order might be different because + # legend_elements uses np.unique instead of pd.unique, + # FacetGrid.add_legend might have troubles with this: + hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) + hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) + handles += hdl + labels += lbl + legend = ax.legend(handles, labels, framealpha=0.5) + _adjust_legend_subtitles(legend) + + if add_colorbar and _data["hue_label"]: + if _data["hue_style"] == "discrete": + raise NotImplementedError("Cannot create a colorbar for non numerics.") + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = _data["hue_label"] + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + return primitive + + # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. class _PlotMethods: @@ -468,6 +882,10 @@ def line(self, *args, **kwargs): def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) + @functools.wraps(scatter) + def _scatter(self, *args, **kwargs): + return scatter(self._da, *args, **kwargs) + def override_signature(f): def wrapper(func): @@ -735,8 +1153,8 @@ def newplotfunc( dims = (yval.dims[0], xval.dims[0]) # better to pass the ndarrays directly to plotting functions - xval = xval.values - yval = yval.values + xval = xval.to_numpy() + yval = yval.to_numpy() # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 85f9c8c5a86..f2f296096a5 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -9,6 +9,7 @@ import pandas as pd from ..core.options import OPTIONS +from ..core.pycompat import DuckArrayModule from ..core.utils import is_scalar try: @@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""): else: name = "" - if da.attrs.get("units"): - units = " [{}]".format(da.attrs["units"]) - elif da.attrs.get("unit"): - units = " [{}]".format(da.attrs["unit"]) + def _get_units_from_attrs(da): + if da.attrs.get("units"): + units = " [{}]".format(da.attrs["units"]) + elif da.attrs.get("unit"): + units = " [{}]".format(da.attrs["unit"]) + else: + units = "" + return units + + pint_array_type = DuckArrayModule("pint").type + if isinstance(da.data, pint_array_type): + units = " [{}]".format(str(da.data.units)) else: - units = "" + units = _get_units_from_attrs(da) return "\n".join(textwrap.wrap(name + extra + units, 30)) @@ -896,6 +905,234 @@ def _get_nice_quiver_magnitude(u, v): import matplotlib as mpl ticker = mpl.ticker.MaxNLocator(3) - mean = np.mean(np.hypot(u.values, v.values)) + mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude + + +# Copied from matplotlib, tweaked so func can return strings. +# https://github.com/matplotlib/matplotlib/issues/19555 +def legend_elements( + self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs +): + """ + Create legend handles and labels for a PathCollection. + + Each legend handle is a `.Line2D` representing the Path that was drawn, + and each label is a string what each Path represents. + + This is useful for obtaining a legend for a `~.Axes.scatter` plot; + e.g.:: + + scatter = plt.scatter([1, 2, 3], [4, 5, 6], c=[7, 2, 3]) + plt.legend(*scatter.legend_elements()) + + creates three legend elements, one for each color with the numerical + values passed to *c* as the labels. + + Also see the :ref:`automatedlegendcreation` example. + + + Parameters + ---------- + prop : {"colors", "sizes"}, default: "colors" + If "colors", the legend handles will show the different colors of + the collection. If "sizes", the legend will show the different + sizes. To set both, use *kwargs* to directly edit the `.Line2D` + properties. + num : int, None, "auto" (default), array-like, or `~.ticker.Locator` + Target number of elements to create. + If None, use all unique elements of the mappable array. If an + integer, target to use *num* elements in the normed range. + If *"auto"*, try to determine which option better suits the nature + of the data. + The number of created elements may slightly deviate from *num* due + to a `~.ticker.Locator` being used to find useful locations. + If a list or array, use exactly those elements for the legend. + Finally, a `~.ticker.Locator` can be provided. + fmt : str, `~matplotlib.ticker.Formatter`, or None (default) + The format or formatter to use for the labels. If a string must be + a valid input for a `~.StrMethodFormatter`. If None (the default), + use a `~.ScalarFormatter`. + func : function, default: ``lambda x: x`` + Function to calculate the labels. Often the size (or color) + argument to `~.Axes.scatter` will have been pre-processed by the + user using a function ``s = f(x)`` to make the markers visible; + e.g. ``size = np.log10(x)``. Providing the inverse of this + function here allows that pre-processing to be inverted, so that + the legend labels have the correct values; e.g. ``func = lambda + x: 10**x``. + **kwargs + Allowed keyword arguments are *color* and *size*. E.g. it may be + useful to set the color of the markers if *prop="sizes"* is used; + similarly to set the size of the markers if *prop="colors"* is + used. Any further parameters are passed onto the `.Line2D` + instance. This may be useful to e.g. specify a different + *markeredgecolor* or *alpha* for the legend handles. + + Returns + ------- + handles : list of `.Line2D` + Visual representation of each element of the legend. + labels : list of str + The string labels for elements of the legend. + """ + import warnings + + import matplotlib as mpl + + mlines = mpl.lines + + handles = [] + labels = [] + + if prop == "colors": + arr = self.get_array() + if arr is None: + warnings.warn( + "Collection without array used. Make sure to " + "specify the values to be colormapped via the " + "`c` argument." + ) + return handles, labels + _size = kwargs.pop("size", mpl.rcParams["lines.markersize"]) + + def _get_color_and_size(value): + return self.cmap(self.norm(value)), _size + + elif prop == "sizes": + arr = self.get_sizes() + _color = kwargs.pop("color", "k") + + def _get_color_and_size(value): + return _color, np.sqrt(value) + + else: + raise ValueError( + "Valid values for `prop` are 'colors' or " + f"'sizes'. You supplied '{prop}' instead." + ) + + # Get the unique values and their labels: + values = np.unique(arr) + label_values = np.asarray(func(values)) + label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) + + # Handle the label format: + if fmt is None and label_values_are_numeric: + fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True) + elif fmt is None and not label_values_are_numeric: + fmt = mpl.ticker.StrMethodFormatter("{x}") + elif isinstance(fmt, str): + fmt = mpl.ticker.StrMethodFormatter(fmt) + fmt.create_dummy_axis() + + if num == "auto": + num = 9 + if len(values) <= num: + num = None + + if label_values_are_numeric: + label_values_min = label_values.min() + label_values_max = label_values.max() + fmt.set_bounds(label_values_min, label_values_max) + + if num is not None: + # Labels are numerical but larger than the target + # number of elements, reduce to target using matplotlibs + # ticker classes: + if isinstance(num, mpl.ticker.Locator): + loc = num + elif np.iterable(num): + loc = mpl.ticker.FixedLocator(num) + else: + num = int(num) + loc = mpl.ticker.MaxNLocator( + nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10] + ) + + # Get nicely spaced label_values: + label_values = loc.tick_values(label_values_min, label_values_max) + + # Remove extrapolated label_values: + cond = (label_values >= label_values_min) & ( + label_values <= label_values_max + ) + label_values = label_values[cond] + + # Get the corresponding values by creating a linear interpolant + # with small step size: + values_interp = np.linspace(values.min(), values.max(), 256) + label_values_interp = func(values_interp) + ix = np.argsort(label_values_interp) + values = np.interp(label_values, label_values_interp[ix], values_interp[ix]) + elif num is not None and not label_values_are_numeric: + # Labels are not numerical so modifying label_values is not + # possible, instead filter the array with nicely distributed + # indexes: + if type(num) == int: + loc = mpl.ticker.LinearLocator(num) + else: + raise ValueError("`num` only supports integers for non-numeric labels.") + + ind = loc.tick_values(0, len(label_values) - 1).astype(int) + label_values = label_values[ind] + values = values[ind] + + # Some formatters requires set_locs: + if hasattr(fmt, "set_locs"): + fmt.set_locs(label_values) + + # Default settings for handles, add or override with kwargs: + kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha()) + kw.update(kwargs) + + for val, lab in zip(values, label_values): + color, size = _get_color_and_size(val) + h = mlines.Line2D( + [0], [0], ls="", color=color, ms=size, marker=self.get_paths()[0], **kw + ) + handles.append(h) + labels.append(fmt(lab)) + + return handles, labels + + +def _legend_add_subtitle(handles, labels, text, func): + """Add a subtitle to legend handles.""" + if text and len(handles) > 1: + # Create a blank handle that's not visible, the + # invisibillity will be used to discern which are subtitles + # or not: + blank_handle = func([], [], label=text) + blank_handle.set_visible(False) + + # Subtitles are shown first: + handles = [blank_handle] + handles + labels = [text] + labels + + return handles, labels + + +def _adjust_legend_subtitles(legend): + """Make invisible-handle "subtitles" entries look more like titles.""" + plt = import_matplotlib_pyplot() + + # Legend title not in rcParams until 3.0 + font_size = plt.rcParams.get("legend.title_fontsize", None) + hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children() + for hpack in hpackers: + draw_area, text_area = hpack.get_children() + handles = draw_area.get_children() + + # Assume that all artists that are not visible are + # subtitles: + if not all(artist.get_visible() for artist in handles): + # Remove the dummy marker which will bring the text + # more to the center: + draw_area.set_width(0) + for text in text_area.get_children(): + if font_size is not None: + # The sutbtitles should have the same font size + # as normal legend titles: + text.set_size(font_size) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 9029dc1c621..d757fb451cc 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -83,6 +83,7 @@ def LooseVersion(vstring): has_numbagg, requires_numbagg = _importorskip("numbagg") has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") +has_cupy, requires_cupy = _importorskip("cupy") has_cartopy, requires_cartopy = _importorskip("cartopy") # Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5079cd390f1..3bbc2c93b31 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -42,7 +42,7 @@ from xarray.backends.scipy_ import ScipyBackendEntrypoint from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates -from xarray.core import indexes, indexing +from xarray.core import indexing from xarray.core.options import set_options from xarray.core.pycompat import dask_array_type from xarray.tests import LooseVersion, mock @@ -87,12 +87,8 @@ try: import dask import dask.array as da - - dask_version = dask.__version__ except ImportError: - # needed for xfailed tests when dask < 2.4.0 - # remove when min dask > 2.4.0 - dask_version = "10.0" + pass ON_WINDOWS = sys.platform == "win32" default_value = object() @@ -742,7 +738,7 @@ def find_and_validate_array(obj): elif isinstance(obj.array, dask_array_type): assert isinstance(obj, indexing.DaskIndexingAdapter) elif isinstance(obj.array, pd.Index): - assert isinstance(obj, indexes.PandasIndex) + assert isinstance(obj, indexing.PandasIndexingAdapter) else: raise TypeError( "{} is wrapped by {}".format(type(obj.array), type(obj)) @@ -1961,7 +1957,6 @@ def test_hidden_zarr_keys(self): with xr.decode_cf(store): pass - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") @pytest.mark.parametrize("group", [None, "group1"]) def test_write_persistence_modes(self, group): original = create_test_data() @@ -2039,7 +2034,6 @@ def test_encoding_kwarg_fixed_width_string(self): def test_dataset_caching(self): super().test_dataset_caching() - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_write(self): super().test_append_write() @@ -2122,7 +2116,6 @@ def test_check_encoding_is_consistent_after_append(self): xr.concat([ds, ds_to_append], dim="time"), ) - @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_with_new_variable(self): ds, ds_to_append, ds_with_new_var = create_append_test_data() @@ -2777,6 +2770,7 @@ def test_dump_encodings_h5py(self): @requires_h5netcdf +@requires_netCDF4 class TestH5NetCDFAlreadyOpen: def test_open_dataset_group(self): import h5netcdf @@ -2861,6 +2855,7 @@ def test_open_twice(self): with open_dataset(f, engine="h5netcdf"): pass + @requires_scipy def test_open_fileobj(self): # open in-memory datasets instead of local file paths expected = create_test_data().drop_vars("dim3") @@ -5162,11 +5157,12 @@ def test_open_fsspec(): @requires_h5netcdf +@requires_netCDF4 def test_load_single_value_h5netcdf(tmp_path): """Test that numeric single-element vector attributes are handled fine. At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable - attributes as arrays of length 1, as oppesed to scalars for the NetCDF4 + attributes as arrays of length 1, as opposed to scalars for the NetCDF4 backend. This was leading to a ValueError upon loading a single value from a file, see #4471. Test that loading causes no failure. """ diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index 503c742252a..278a961166f 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -153,39 +153,6 @@ def test_coarsen_keep_attrs(funcname, argument): assert result.da_not_coarsend.name == "da_not_coarsend" -def test_coarsen_keep_attrs_deprecated(): - global_attrs = {"units": "test", "long_name": "testing"} - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - ds = Dataset( - data_vars={"da": ("coord", data)}, - coords={"coord": coords}, - attrs=global_attrs, - ) - ds.da.attrs = attrs_da - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated" - ): - result = ds.coarsen(dim={"coord": 5}, keep_attrs=False).mean() - - assert result.attrs == {} - assert result.da.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated" - ): - result = ds.coarsen(dim={"coord": 5}, keep_attrs=True).mean(keep_attrs=False) - - assert result.attrs == {} - assert result.da.attrs == {} - - @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @@ -267,31 +234,6 @@ def test_coarsen_da_keep_attrs(funcname, argument): assert result.name == "name" -def test_coarsen_da_keep_attrs_deprecated(): - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da) - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated" - ): - result = da.coarsen(dim={"coord": 5}, keep_attrs=False).mean() - - assert result.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``coarsen`` is deprecated" - ): - result = da.coarsen(dim={"coord": 5}, keep_attrs=True).mean(keep_attrs=False) - - assert result.attrs == {} - - @pytest.mark.parametrize("da", (1, 2), indirect=True) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 09bed72496b..2439ea30b4b 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,7 +1,6 @@ import functools import operator import pickle -from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -21,6 +20,7 @@ result_name, unified_dim_sizes, ) +from xarray.core.pycompat import dask_version from . import has_dask, raise_if_dask_computes, requires_dask @@ -1307,7 +1307,7 @@ def test_vectorize_dask_dtype_without_output_dtypes(data_array): @pytest.mark.skipif( - LooseVersion(dask.__version__) > "2021.06", + dask_version > "2021.06", reason="dask/dask#7669: can no longer pass output_dtypes and meta", ) @requires_dask diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f790587efa9..d5d460056aa 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -2,7 +2,6 @@ import pickle import sys from contextlib import suppress -from distutils.version import LooseVersion from textwrap import dedent import numpy as np @@ -13,6 +12,7 @@ import xarray.ufuncs as xu from xarray import DataArray, Dataset, Variable from xarray.core import duck_array_ops +from xarray.core.pycompat import dask_version from xarray.testing import assert_chunks_equal from xarray.tests import mock @@ -111,10 +111,7 @@ def test_indexing(self): self.assertLazyAndIdentical(u[:1], v[:1]) self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]]) - @pytest.mark.skipif( - LooseVersion(dask.__version__) < LooseVersion("2021.04.1"), - reason="Requires dask v2021.04.1 or later", - ) + @pytest.mark.skipif(dask_version < "2021.04.1", reason="Requires dask >= 2021.04.1") @pytest.mark.parametrize( "expected_data, index", [ @@ -133,10 +130,7 @@ def test_setitem_dask_array(self, expected_data, index): arr[index] = 99 assert_identical(arr, expected) - @pytest.mark.skipif( - LooseVersion(dask.__version__) >= LooseVersion("2021.04.1"), - reason="Requires dask v2021.04.0 or earlier", - ) + @pytest.mark.skipif(dask_version >= "2021.04.1", reason="Requires dask < 2021.04.1") def test_setitem_dask_array_error(self): with pytest.raises(TypeError, match=r"stored in a dask array"): v = self.lazy_var @@ -612,25 +606,6 @@ def test_dot(self): lazy = self.lazy_array.dot(self.lazy_array[0]) self.assertLazyAndAllClose(eager, lazy) - @pytest.mark.skipif(LooseVersion(dask.__version__) >= "2.0", reason="no meta") - def test_dataarray_repr_legacy(self): - data = build_dask_array("data") - nonindex_coord = build_dask_array("coord") - a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)}) - expected = dedent( - """\ - - {!r} - Coordinates: - y (x) int64 dask.array - Dimensions without coordinates: x""".format( - data - ) - ) - assert expected == repr(a) - assert kernel_call_count == 0 # should not evaluate dask array - - @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta") def test_dataarray_repr(self): data = build_dask_array("data") nonindex_coord = build_dask_array("coord") @@ -648,7 +623,6 @@ def test_dataarray_repr(self): assert expected == repr(a) assert kernel_call_count == 0 # should not evaluate dask array - @pytest.mark.skipif(LooseVersion(dask.__version__) < "2.0", reason="needs meta") def test_dataset_repr(self): data = build_dask_array("data") nonindex_coord = build_dask_array("coord") @@ -1619,7 +1593,7 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy) assert_equal(map_ds.map(lambda x: x), map_ds) assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds) - assert_equal(map_ds.update({"a": map_ds.a}), map_ds) + assert_equal(map_ds.assign({"a": map_ds.a}), map_ds) # fails because of index error # assert_equal( @@ -1645,7 +1619,7 @@ def test_optimize(): # The graph_manipulation module is in dask since 2021.2 but it became usable with # xarray only since 2021.3 -@pytest.mark.skipif(LooseVersion(dask.__version__) <= "2021.02.0", reason="new module") +@pytest.mark.skipif(dask_version <= "2021.02.0", reason="new module") def test_graph_manipulation(): """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder function returned by __dask_postperist__; also, the dsk passed to the rebuilder is diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b9f04085935..8ab8bc872da 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -36,10 +36,12 @@ has_dask, raise_if_dask_computes, requires_bottleneck, + requires_cupy, requires_dask, requires_iris, requires_numbagg, requires_numexpr, + requires_pint_0_15, requires_scipy, requires_sparse, source_ndarray, @@ -148,7 +150,9 @@ def test_data_property(self): def test_indexes(self): array = DataArray(np.zeros((2, 3)), [("x", [0, 1]), ("y", ["a", "b", "c"])]) expected_indexes = {"x": pd.Index([0, 1]), "y": pd.Index(["a", "b", "c"])} - expected_xindexes = {k: PandasIndex(idx) for k, idx in expected_indexes.items()} + expected_xindexes = { + k: PandasIndex(idx, k) for k, idx in expected_indexes.items() + } assert array.xindexes.keys() == expected_xindexes.keys() assert array.indexes.keys() == expected_indexes.keys() assert all([isinstance(idx, pd.Index) for idx in array.indexes.values()]) @@ -1471,7 +1475,7 @@ def test_coords_alignment(self): def test_set_coords_update_index(self): actual = DataArray([1, 2, 3], [("x", [1, 2, 3])]) actual.coords["x"] = ["a", "b", "c"] - assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) def test_coords_replacement_alignment(self): # regression test for GH725 @@ -1635,15 +1639,6 @@ def test_init_value(self): DataArray(np.array(1), coords=[("x", np.arange(10))]) def test_swap_dims(self): - array = DataArray(np.random.randn(3), {"y": ("x", list("abc"))}, "x") - expected = DataArray(array.values, {"y": list("abc")}, dims="y") - actual = array.swap_dims({"x": "y"}) - assert_identical(expected, actual) - for dim_name in set().union(expected.xindexes.keys(), actual.xindexes.keys()): - pd.testing.assert_index_equal( - expected.xindexes[dim_name].array, actual.xindexes[dim_name].array - ) - array = DataArray(np.random.randn(3), {"x": list("abc")}, "x") expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y") actual = array.swap_dims({"x": "y"}) @@ -6865,33 +6860,6 @@ def test_rolling_keep_attrs(funcname, argument): assert result.name == "name" -def test_rolling_keep_attrs_deprecated(): - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - da = DataArray(data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da) - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = da.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") - - assert result.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = da.rolling(dim={"coord": 5}, keep_attrs=True).construct( - "window_dim", keep_attrs=False - ) - - assert result.attrs == {} - - def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 @@ -7375,3 +7343,87 @@ def test_drop_duplicates(keep): expected = xr.DataArray(data, dims="time", coords={"time": time}, name="test") result = ds.drop_duplicates("time", keep=keep) assert_equal(expected, result) + + +class TestNumpyCoercion: + # TODO once flexible indexes refactor complete also test coercion of dimension coords + def test_from_numpy(self): + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(da.as_numpy(), da) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_dask + def test_from_dask(self): + da = xr.DataArray([1, 2, 3], dims="x", coords={"lat": ("x", [4, 5, 6])}) + da_chunked = da.chunk(1) + + assert_identical(da_chunked.as_numpy(), da.compute()) + np.testing.assert_equal(da.to_numpy(), np.array([1, 2, 3])) + np.testing.assert_equal(da["lat"].to_numpy(), np.array([4, 5, 6])) + + @requires_pint_0_15 + def test_from_pint(self): + from pint import Quantity + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + Quantity(arr, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + np.testing.assert_equal(da["lat"].to_numpy(), arr + 3) + + @requires_sparse + def test_from_sparse(self): + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + da = xr.DataArray( + sparr, dims=["x", "y"], coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.DataArray( + arr, dims=["x", "y"], coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self): + import cupy as cp + + arr = np.array([1, 2, 3]) + da = xr.DataArray( + cp.array(arr), dims="x", coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr + 3)}) + assert_identical(da.as_numpy(), expected) + np.testing.assert_equal(da.to_numpy(), arr) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + da = xr.DataArray( + Quantity(d, units="Pa"), + dims="x", + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = da.as_numpy() + result.name = None # remove dask-assigned name + expected = xr.DataArray(arr, dims="x", coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) + np.testing.assert_equal(da.to_numpy(), arr) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b08ce9ea730..8e39bbdd83e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -44,9 +44,11 @@ has_dask, requires_bottleneck, requires_cftime, + requires_cupy, requires_dask, requires_numbagg, requires_numexpr, + requires_pint_0_15, requires_scipy, requires_sparse, source_ndarray, @@ -728,7 +730,7 @@ def test_coords_modify(self): def test_update_index(self): actual = Dataset(coords={"x": [1, 2, 3]}) actual["x"] = ["a", "b", "c"] - assert actual.xindexes["x"].equals(pd.Index(["a", "b", "c"])) + assert actual.xindexes["x"].to_pandas_index().equals(pd.Index(["a", "b", "c"])) def test_coords_setitem_with_new_dimension(self): actual = Dataset() @@ -3191,13 +3193,13 @@ def test_update(self): data = create_test_data(seed=0) expected = data.copy() var2 = Variable("dim1", np.arange(8)) - actual = data.update({"var2": var2}) + actual = data + actual.update({"var2": var2}) expected["var2"] = var2 assert_identical(expected, actual) actual = data.copy() - actual_result = actual.update(data) - assert actual_result is actual + actual.update(data) assert_identical(expected, actual) other = Dataset(attrs={"new": "attr"}) @@ -3557,6 +3559,7 @@ def test_setitem_align_new_indexes(self): def test_setitem_str_dtype(self, dtype): ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) + # test Dataset update ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"]) assert np.issubdtype(ds.x.dtype, dtype) @@ -5195,10 +5198,19 @@ def test_dataset_transpose(self): expected_dims = tuple(d for d in new_order if d in ds[k].dims) assert actual[k].dims == expected_dims - with pytest.raises(ValueError, match=r"permuted"): - ds.transpose("dim1", "dim2", "dim3") - with pytest.raises(ValueError, match=r"permuted"): - ds.transpose("dim1", "dim2", "dim3", "time", "extra_dim") + # test missing dimension, raise error + with pytest.raises(ValueError): + ds.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = ds.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = ds.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + actual = ds.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) assert "T" not in dir(ds) @@ -6100,41 +6112,6 @@ def test_rolling_keep_attrs(funcname, argument): assert result.da_not_rolled.name == "da_not_rolled" -def test_rolling_keep_attrs_deprecated(): - global_attrs = {"units": "test", "long_name": "testing"} - attrs_da = {"da_attr": "test"} - - data = np.linspace(10, 15, 100) - coords = np.linspace(1, 10, 100) - - ds = Dataset( - data_vars={"da": ("coord", data)}, - coords={"coord": coords}, - attrs=global_attrs, - ) - ds.da.attrs = attrs_da - - # deprecated option - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = ds.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") - - assert result.attrs == {} - assert result.da.attrs == {} - - # the keep_attrs in the reduction function takes precedence - with pytest.warns( - FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" - ): - result = ds.rolling(dim={"coord": 5}, keep_attrs=True).construct( - "window_dim", keep_attrs=False - ) - - assert result.attrs == {} - assert result.da.attrs == {} - - def test_rolling_properties(ds): # catching invalid args with pytest.raises(ValueError, match="window must be > 0"): @@ -6580,9 +6557,6 @@ def test_integrate(dask): with pytest.raises(ValueError): da.integrate("x2d") - with pytest.warns(FutureWarning): - da.integrate(dim="x") - @requires_scipy @pytest.mark.parametrize("dask", [True, False]) @@ -6751,3 +6725,74 @@ def test_clip(ds): result = ds.clip(min=ds.mean("y"), max=ds.mean("y")) assert result.dims == ds.dims + + +class TestNumpyCoercion: + def test_from_numpy(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + + assert_identical(ds.as_numpy(), ds) + + @requires_dask + def test_from_dask(self): + ds = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6])}) + ds_chunked = ds.chunk(1) + + assert_identical(ds_chunked.as_numpy(), ds.compute()) + + @requires_pint_0_15 + def test_from_pint(self): + from pint import Quantity + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", Quantity(arr, units="Pa"))}, + coords={"lat": ("x", Quantity(arr + 3, units="m"))}, + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_sparse + def test_from_sparse(self): + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO.from_numpy(arr) + ds = xr.Dataset( + {"a": (["x", "y"], sparr)}, coords={"elev": (("x", "y"), sparr + 3)} + ) + + expected = xr.Dataset( + {"a": (["x", "y"], arr)}, coords={"elev": (("x", "y"), arr + 3)} + ) + assert_identical(ds.as_numpy(), expected) + + @requires_cupy + def test_from_cupy(self): + import cupy as cp + + arr = np.array([1, 2, 3]) + ds = xr.Dataset( + {"a": ("x", cp.array(arr))}, coords={"lat": ("x", cp.array(arr + 3))} + ) + + expected = xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"lat": ("x", arr + 3)}) + assert_identical(ds.as_numpy(), expected) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(arr) + ds = xr.Dataset( + {"a": ("x", Quantity(d, units="Pa"))}, + coords={"lat": ("x", Quantity(d, units="m") * 2)}, + ) + + result = ds.as_numpy() + expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)}) + assert_identical(result, expected) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 433e2e58de2..ab0d1d9f22c 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -184,11 +184,7 @@ def test_dask_distributed_cfgrib_integration_test(loop): assert_allclose(actual, expected) -@pytest.mark.skipif( - distributed.__version__ <= "1.19.3", - reason="Need recent distributed version to clean up get", -) -@gen_cluster(client=True, timeout=None) +@gen_cluster(client=True) async def test_async(c, s, a, b): x = create_test_data() assert not dask.is_dask_collection(x) diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 47640ef2d95..09c6fa0cf3c 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -1,5 +1,3 @@ -from distutils.version import LooseVersion - import numpy as np import pandas as pd import pytest @@ -57,19 +55,9 @@ def test_short_data_repr_html_non_str_keys(dataset): def test_short_data_repr_html_dask(dask_dataarray): - import dask - - if LooseVersion(dask.__version__) < "2.0.0": - assert not hasattr(dask_dataarray.data, "_repr_html_") - data_repr = fh.short_data_repr_html(dask_dataarray) - assert ( - data_repr - == "dask.array<xarray-<this-array>, shape=(4, 6), dtype=float64, chunksize=(4, 6)>" - ) - else: - assert hasattr(dask_dataarray.data, "_repr_html_") - data_repr = fh.short_data_repr_html(dask_dataarray) - assert data_repr == dask_dataarray.data._repr_html_() + assert hasattr(dask_dataarray.data, "_repr_html_") + data_repr = fh.short_data_repr_html(dask_dataarray) + assert data_repr == dask_dataarray.data._repr_html_() def test_format_dims_no_dims(): diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index defc6212228..c8ba72a253f 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -2,7 +2,9 @@ import pandas as pd import pytest +import xarray as xr from xarray.core.indexes import PandasIndex, PandasMultiIndex, _asarray_tuplesafe +from xarray.core.variable import IndexVariable def test_asarray_tuplesafe(): @@ -18,9 +20,57 @@ def test_asarray_tuplesafe(): class TestPandasIndex: + def test_constructor(self): + pd_idx = pd.Index([1, 2, 3]) + index = PandasIndex(pd_idx, "x") + + assert index.index is pd_idx + assert index.dim == "x" + + def test_from_variables(self): + var = xr.Variable( + "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + ) + + index, index_vars = PandasIndex.from_variables({"x": var}) + xr.testing.assert_identical(var.to_index_variable(), index_vars["x"]) + assert index.dim == "x" + assert index.index.equals(index_vars["x"].to_index()) + + var2 = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=r".*only accepts one variable.*"): + PandasIndex.from_variables({"x": var, "foo": var2}) + + with pytest.raises( + ValueError, match=r".*only accepts a 1-dimensional variable.*" + ): + PandasIndex.from_variables({"foo": var2}) + + def test_from_pandas_index(self): + pd_idx = pd.Index([1, 2, 3], name="foo") + + index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") + + assert index.dim == "x" + assert index.index is pd_idx + assert index.index.name == "foo" + xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) + + # test no name set for pd.Index + pd_idx.name = None + index, index_vars = PandasIndex.from_pandas_index(pd_idx, "x") + assert "x" in index_vars + assert index.index is not pd_idx + assert index.index.name == "x" + + def to_pandas_index(self): + pd_idx = pd.Index([1, 2, 3], name="foo") + index = PandasIndex(pd_idx, "x") + assert index.to_pandas_index() is pd_idx + def test_query(self): # TODO: add tests that aren't just for edge cases - index = PandasIndex(pd.Index([1, 2, 3])) + index = PandasIndex(pd.Index([1, 2, 3]), "x") with pytest.raises(KeyError, match=r"not all values found"): index.query({"x": [0]}) with pytest.raises(KeyError): @@ -29,7 +79,9 @@ def test_query(self): index.query({"x": {"one": 0}}) def test_query_datetime(self): - index = PandasIndex(pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"])) + index = PandasIndex( + pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"]), "x" + ) actual = index.query({"x": "2001-01-01"}) expected = (1, None) assert actual == expected @@ -38,18 +90,96 @@ def test_query_datetime(self): assert actual == expected def test_query_unsorted_datetime_index_raises(self): - index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"])) + index = PandasIndex(pd.to_datetime(["2001", "2000", "2002"]), "x") with pytest.raises(KeyError): # pandas will try to convert this into an array indexer. We should # raise instead, so we can be sure the result of indexing with a # slice is always a view. index.query({"x": slice("2001", "2002")}) + def test_equals(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([1, 2, 3], "x") + assert index1.equals(index2) is True + + def test_union(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([4, 5, 6], "y") + actual = index1.union(index2) + assert actual.index.equals(pd.Index([1, 2, 3, 4, 5, 6])) + assert actual.dim == "x" + + def test_intersection(self): + index1 = PandasIndex([1, 2, 3], "x") + index2 = PandasIndex([2, 3, 4], "y") + actual = index1.intersection(index2) + assert actual.index.equals(pd.Index([2, 3])) + assert actual.dim == "x" + + def test_copy(self): + expected = PandasIndex([1, 2, 3], "x") + actual = expected.copy() + + assert actual.index.equals(expected.index) + assert actual.index is not expected.index + assert actual.dim == expected.dim + + def test_getitem(self): + pd_idx = pd.Index([1, 2, 3]) + expected = PandasIndex(pd_idx, "x") + actual = expected[1:] + + assert actual.index.equals(pd_idx[1:]) + assert actual.dim == expected.dim + class TestPandasMultiIndex: + def test_from_variables(self): + v_level1 = xr.Variable( + "x", [1, 2, 3], attrs={"unit": "m"}, encoding={"dtype": np.int32} + ) + v_level2 = xr.Variable( + "x", ["a", "b", "c"], attrs={"unit": "m"}, encoding={"dtype": "U"} + ) + + index, index_vars = PandasMultiIndex.from_variables( + {"level1": v_level1, "level2": v_level2} + ) + + expected_idx = pd.MultiIndex.from_arrays([v_level1.data, v_level2.data]) + assert index.dim == "x" + assert index.index.equals(expected_idx) + + assert list(index_vars) == ["x", "level1", "level2"] + xr.testing.assert_equal(xr.IndexVariable("x", expected_idx), index_vars["x"]) + xr.testing.assert_identical(v_level1.to_index_variable(), index_vars["level1"]) + xr.testing.assert_identical(v_level2.to_index_variable(), index_vars["level2"]) + + var = xr.Variable(("x", "y"), [[1, 2, 3], [4, 5, 6]]) + with pytest.raises( + ValueError, match=r".*only accepts 1-dimensional variables.*" + ): + PandasMultiIndex.from_variables({"var": var}) + + v_level3 = xr.Variable("y", [4, 5, 6]) + with pytest.raises(ValueError, match=r"unmatched dimensions for variables.*"): + PandasMultiIndex.from_variables({"level1": v_level1, "level3": v_level3}) + + def test_from_pandas_index(self): + pd_idx = pd.MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=("foo", "bar")) + + index, index_vars = PandasMultiIndex.from_pandas_index(pd_idx, "x") + + assert index.dim == "x" + assert index.index is pd_idx + assert index.index.names == ("foo", "bar") + xr.testing.assert_identical(index_vars["x"], IndexVariable("x", pd_idx)) + xr.testing.assert_identical(index_vars["foo"], IndexVariable("x", [1, 2, 3])) + xr.testing.assert_identical(index_vars["bar"], IndexVariable("x", [4, 5, 6])) + def test_query(self): index = PandasMultiIndex( - pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")) + pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two")), "x" ) # test tuples inside slice are considered as scalar indexer values assert index.query({"x": slice(("a", 1), ("b", 2))}) == (slice(0, 4), None) diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 1909d309cf5..6e4fd320029 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -81,9 +81,12 @@ def test_group_indexers_by_index(self): def test_remap_label_indexers(self): def test_indexer(data, x, expected_pos, expected_idx=None): - pos, idx = indexing.remap_label_indexers(data, {"x": x}) + pos, new_idx_vars = indexing.remap_label_indexers(data, {"x": x}) + idx, _ = new_idx_vars.get("x", (None, None)) + if idx is not None: + idx = idx.to_pandas_index() assert_array_equal(pos.get("x"), expected_pos) - assert_array_equal(idx.get("x"), expected_idx) + assert_array_equal(idx, expected_idx) data = Dataset({"x": ("x", [1, 2, 3])}) mindex = pd.MultiIndex.from_product( diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 4f6dc616504..2029e6af05b 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -727,6 +727,7 @@ def test_datetime_interp_noerror(): @requires_cftime +@requires_scipy def test_3641(): times = xr.cftime_range("0001", periods=3, freq="500Y") da = xr.DataArray(range(3), dims=["time"], coords=[times]) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a5ffb97db38..ee8bafb8fa7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2912,3 +2912,41 @@ def test_maybe_gca(): assert existing_axes == ax # kwargs are ignored when reusing axes assert ax.get_aspect() == "auto" + + +@requires_matplotlib +@pytest.mark.parametrize( + "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", + [ + ("A", "B", None, None, None, None, None, None, None), + ("B", "A", None, "w", None, None, None, True, None), + ("A", "B", None, "y", "x", None, None, True, True), + ("A", "B", "z", None, None, None, None, None, None), + ("B", "A", "z", "w", None, None, None, True, None), + ("A", "B", "z", "y", "x", None, None, True, True), + ("A", "B", "z", "y", "x", "w", None, True, True), + ], +) +def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar): + """Test datarray scatter. Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset() + + extra_coords = [v for v in [x, hue, markersize] if v is not None] + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + coords.update({v: ds[v] for v in extra_coords}) + + darray = xr.DataArray(ds[y], coords=coords) + + with figure_context(): + darray.plot._scatter( + x=x, + z=z, + hue=hue, + markersize=markersize, + add_legend=add_legend, + add_colorbar=add_colorbar, + ) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 17086049cc7..2140047f38e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5,10 +5,22 @@ import pandas as pd import pytest +try: + import matplotlib.pyplot as plt +except ImportError: + pass + import xarray as xr from xarray.core import dtypes, duck_array_ops -from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical +from . import ( + assert_allclose, + assert_duckarray_allclose, + assert_equal, + assert_identical, + requires_matplotlib, +) +from .test_plot import PlotTestCase from .test_variable import _PAD_XR_NP_ARGS pint = pytest.importorskip("pint") @@ -5564,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype): assert_units_equal(expected, actual) assert_equal(expected, actual) + + +@requires_matplotlib +class TestPlots(PlotTestCase): + def test_units_in_line_plot_labels(self): + arr = np.linspace(1, 10, 3) * unit_registry.Pa + # TODO make coord a Quantity once unit-aware indexes supported + x_coord = xr.DataArray( + np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"} + ) + da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure") + + da.plot.line() + + ax = plt.gca() + assert ax.get_ylabel() == "pressure [pascal]" + assert ax.get_xlabel() == "x [meters]" + + def test_units_in_2d_plot_labels(self): + arr = np.ones((2, 3)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") + + fig, (ax, cax) = plt.subplots(1, 2) + ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True) + + assert cax.get_ylabel() == "pressure [pascal]" diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 9c78caea4d6..ce796e9de49 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -7,7 +7,6 @@ from xarray.coding.cftimeindex import CFTimeIndex from xarray.core import duck_array_ops, utils -from xarray.core.indexes import PandasIndex from xarray.core.utils import either_dict_or_kwargs, iterate_nested from . import assert_array_equal, requires_cftime, requires_dask @@ -29,13 +28,11 @@ def test_safe_cast_to_index(): dates = pd.date_range("2000-01-01", periods=10) x = np.arange(5) td = x * np.timedelta64(1, "D") - midx = pd.MultiIndex.from_tuples([(0,)], names=["a"]) for expected, array in [ (dates, dates.values), (pd.Index(x, dtype=object), x.astype(object)), (pd.Index(td), td), (pd.Index(td, dtype=object), td.astype(object)), - (midx, PandasIndex(midx)), ]: actual = utils.safe_cast_to_index(array) assert_array_equal(expected, actual) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 1e0dff45dd2..487c9b34336 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -11,7 +11,6 @@ from xarray import Coordinate, DataArray, Dataset, IndexVariable, Variable, set_options from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.common import full_like, ones_like, zeros_like -from xarray.core.indexes import PandasIndex from xarray.core.indexing import ( BasicIndexer, CopyOnWriteArray, @@ -20,6 +19,7 @@ MemoryCachedArray, NumpyIndexingAdapter, OuterIndexer, + PandasIndexingAdapter, VectorizedIndexer, ) from xarray.core.pycompat import dask_array_type @@ -33,7 +33,9 @@ assert_equal, assert_identical, raise_if_dask_computes, + requires_cupy, requires_dask, + requires_pint_0_15, requires_sparse, source_ndarray, ) @@ -535,7 +537,7 @@ def test_copy_index(self): v = self.cls("x", midx) for deep in [True, False]: w = v.copy(deep=deep) - assert isinstance(w._data, PandasIndex) + assert isinstance(w._data, PandasIndexingAdapter) assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) @@ -1160,7 +1162,7 @@ def test_as_variable(self): td = np.array([timedelta(days=x) for x in range(10)]) assert as_variable(td, "time").dtype.kind == "m" - with pytest.warns(DeprecationWarning): + with pytest.raises(TypeError): as_variable(("x", DataArray([]))) def test_repr(self): @@ -1466,6 +1468,20 @@ def test_transpose(self): w3 = Variable(["b", "c", "d", "a"], np.einsum("abcd->bcda", x)) assert_identical(w, w3.transpose("a", "b", "c", "d")) + # test missing dimension, raise error + with pytest.raises(ValueError): + v.transpose(..., "not_a_dim") + + # test missing dimension, ignore error + actual = v.transpose(..., "not_a_dim", missing_dims="ignore") + expected_ell = v.transpose(...) + assert_identical(expected_ell, actual) + + # test missing dimension, raise warning + with pytest.warns(UserWarning): + v.transpose(..., "not_a_dim", missing_dims="warn") + assert_identical(expected_ell, actual) + def test_transpose_0d(self): for value in [ 3.5, @@ -2145,7 +2161,7 @@ def test_multiindex_default_level_names(self): def test_data(self): x = IndexVariable("x", np.arange(3.0)) - assert isinstance(x._data, PandasIndex) + assert isinstance(x._data, PandasIndexingAdapter) assert isinstance(x.data, np.ndarray) assert float == x.dtype assert_array_equal(np.arange(3), x) @@ -2287,7 +2303,7 @@ def test_coarsen_2d(self): class TestAsCompatibleData: def test_unchanged_types(self): - types = (np.asarray, PandasIndex, LazilyIndexedArray) + types = (np.asarray, PandasIndexingAdapter, LazilyIndexedArray) for t in types: for data in [ np.arange(3), @@ -2540,3 +2556,68 @@ def test_clip(var): var.mean("z").data[:, :, np.newaxis], ), ) + + +@pytest.mark.parametrize("Var", [Variable, IndexVariable]) +class TestNumpyCoercion: + def test_from_numpy(self, Var): + v = Var("x", [1, 2, 3]) + + assert_identical(v.as_numpy(), v) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_dask + def test_from_dask(self, Var): + v = Var("x", [1, 2, 3]) + v_chunked = v.chunk(1) + + assert_identical(v_chunked.as_numpy(), v.compute()) + np.testing.assert_equal(v.to_numpy(), np.array([1, 2, 3])) + + @requires_pint_0_15 + def test_from_pint(self, Var): + from pint import Quantity + + arr = np.array([1, 2, 3]) + v = Var("x", Quantity(arr, units="m")) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_sparse + def test_from_sparse(self, Var): + if Var is IndexVariable: + pytest.skip("Can't have 2D IndexVariables") + + import sparse + + arr = np.diagflat([1, 2, 3]) + sparr = sparse.COO(coords=[[0, 1, 2], [0, 1, 2]], data=[1, 2, 3]) + v = Variable(["x", "y"], sparr) + + assert_identical(v.as_numpy(), Variable(["x", "y"], arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_cupy + def test_from_cupy(self, Var): + import cupy as cp + + arr = np.array([1, 2, 3]) + v = Var("x", cp.array(arr)) + + assert_identical(v.as_numpy(), Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr) + + @requires_dask + @requires_pint_0_15 + def test_from_pint_wrapping_dask(self, Var): + import dask + from pint import Quantity + + arr = np.array([1, 2, 3]) + d = dask.array.from_array(np.array([1, 2, 3])) + v = Var("x", Quantity(d, units="m")) + + result = v.as_numpy() + assert_identical(result, Var("x", arr)) + np.testing.assert_equal(v.to_numpy(), arr)