From db36c5c0cdee2f5313a81fdeca8a8ae5491d1c8f Mon Sep 17 00:00:00 2001 From: hazbottles Date: Fri, 3 Jan 2020 12:16:44 +0000 Subject: [PATCH 01/29] add multiindex level name checking to .rename() (#3658) * add multiindex level name checking to .rename() * update whats-new.rst --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 9 ++++++++- xarray/tests/test_dataset.py | 8 ++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 00d1c50780e..70853dbb730 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -62,6 +62,8 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with + MultiIndex level names. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6be06fed117..032c66d3778 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -89,7 +89,13 @@ is_scalar, maybe_wrap_array, ) -from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from .variable import ( + IndexVariable, + Variable, + as_variable, + broadcast_variables, + assert_unique_multiindex_level_names, +) if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -2780,6 +2786,7 @@ def rename( variables, coord_names, dims, indexes = self._rename_all( name_dict=name_dict, dims_dict=name_dict ) + assert_unique_multiindex_level_names(variables) return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7db1911621b..edda4cb2a58 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2461,6 +2461,14 @@ def test_rename_vars(self): with pytest.raises(ValueError): original.rename_vars(names_dict_bad) + def test_rename_multiindex(self): + mindex = pd.MultiIndex.from_tuples( + [([1, 2]), ([3, 4])], names=["level0", "level1"] + ) + data = Dataset({}, {"x": mindex}) + with raises_regex(ValueError, "conflicting MultiIndex"): + data.rename({"x": "level0"}) + @requires_cftime def test_rename_does_not_change_CFTimeIndex_type(self): # make sure CFTimeIndex is not converted to DatetimeIndex #3522 From 8fc9ecedc87b5d878363b233e260c87fd632fa0f Mon Sep 17 00:00:00 2001 From: Riley Brady Date: Wed, 8 Jan 2020 10:50:03 -0700 Subject: [PATCH 02/29] Add map_blocks example to docs. (#3667) --- xarray/core/parallel.py | 42 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index dd6c67338d8..e4fb5803191 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -154,6 +154,48 @@ def map_blocks( -------- dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type='time.month'): + ... # Necessary workaround to xarray's check with zero dimensions + ... # https://github.com/pydata/xarray/issues/3575 + ... if sum(da.shape) == 0: + ... return da + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim='time') + ... 
return gb - clim + >>> time = xr.cftime_range('1990-01', '1992-01', freq='M') + >>> np.random.seed(123) + >>> array = xr.DataArray(np.random.rand(len(time)), + ... dims="time", coords=[time]).chunk() + >>> xr.map_blocks(calculate_anomaly, array).compute() + + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> xr.map_blocks(calculate_anomaly, array, kwargs={'groupby_type': 'time.year'}) + + array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , + -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425, + -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273, + 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 , + 0.14482397, 0.35985481, 0.23487834, 0.12144652]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ def _wrapper(func, obj, to_array, args, kwargs): From 080caf4246fe2f4d6aa0c5dcb65a99b376fa669b Mon Sep 17 00:00:00 2001 From: keewis Date: Wed, 8 Jan 2020 19:27:29 +0100 Subject: [PATCH 03/29] Support swap_dims to dimension names that are not existing variables (#3636) * test that swapping to a non-existing name works * don't try to get a variable if the variable does not exist * don't add dimensions to coord_names if they are not existing variables * add whats-new.rst entry * update the documentation --- doc/whats-new.rst | 3 +++ xarray/core/dataarray.py | 10 ++++++++-- xarray/core/dataset.py | 17 +++++++++++++---- xarray/tests/test_dataarray.py | 5 +++++ xarray/tests/test_dataset.py | 6 ++++++ 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 70853dbb730..351424fbb9f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,9 @@ New Features - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ +- :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` + now allow swapping to dimension names that don't exist yet. (:pull:`3636`) + By `Justus Magin `_. - Extend :py:class:`core.accessor_dt.DatetimeAccessor` properties and support `.dt` accessor for timedelta via :py:class:`core.accessor_dt.TimedeltaAccessor` (:pull:`3612`) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 31aa4da57b2..cbd8d243385 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1480,8 +1480,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": ---------- dims_dict : dict-like Dictionary whose keys are current dimension names and whose values - are new names. Each value must already be a coordinate on this - array. + are new names. 
Returns ------- @@ -1504,6 +1503,13 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": Coordinates: x (y) >> arr.swap_dims({"x": "z"}) + + array([0, 1]) + Coordinates: + x (z) >> ds.swap_dims({"x": "z"}) + + Dimensions: (z: 2) + Coordinates: + x (z) Date: Thu, 9 Jan 2020 02:46:45 +0100 Subject: [PATCH 04/29] raise an error when renaming dimensions to existing names (#3645) * check we raise an error if the name already exists * raise if the new name already exists and point to swap_dims * update the documentation * whats-new.rst * fix the docstring of rename_dims Co-Authored-By: Deepak Cherian --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 10 ++++++++-- xarray/tests/test_dataset.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 351424fbb9f..5a9f2497ed6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -65,6 +65,9 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- Raise an error when trying to use :py:meth:`Dataset.rename_dims` to + rename to an existing name (:issue:`3438`, :pull:`3645`) + By `Justus Magin `_. - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with MultiIndex level names. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 607350c8101..ac0a923db78 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2798,7 +2798,8 @@ def rename_dims( ---------- dims_dict : dict-like, optional Dictionary whose keys are current dimension names and - whose values are the desired names. + whose values are the desired names. The desired names must + not be the name of an existing dimension or Variable in the Dataset. **dims, optional Keyword form of ``dims_dict``. One of dims_dict or dims must be provided. @@ -2816,12 +2817,17 @@ def rename_dims( DataArray.rename """ dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims") - for k in dims_dict: + for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( "cannot rename %r because it is not a " "dimension in this dataset" % k ) + if v in self.dims or v in self: + raise ValueError( + f"Cannot rename {k} to {v} because {v} already exists. " + "Try using swap_dims instead." + ) variables, coord_names, sizes, indexes = self._rename_all( name_dict={}, dims_dict=dims_dict diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 48d8c25b810..2220abbef31 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2444,6 +2444,9 @@ def test_rename_dims(self): with pytest.raises(ValueError): original.rename_dims(dims_dict_bad) + with pytest.raises(ValueError): + original.rename_dims({"x": "z"}) + def test_rename_vars(self): original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) expected = Dataset( From 24f9292114894621d8eb7a4eade6347538ce0d23 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 10 Jan 2020 16:10:56 +0000 Subject: [PATCH 05/29] Make dask names change when chunking Variables by different amounts. (#3584) * Make dask tokens change when chunking Variables by different amounts. When rechunking by the current chunk size, the dask token should not change. Add a __dask_tokenize__ method for ReprObject so that this behaviour is present when DataArrays are converted to temporary Datasets and back. 
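A rough sketch of the intended behaviour (illustrative only; assumes a
small numpy-backed DataArray, mirroring the new assertions in
test_dataarray.py and test_dataset.py below):

    >>> import numpy as np
    >>> import xarray as xr
    >>> da = xr.DataArray(np.zeros(6), dims="x")
    >>> # tokenizing is deterministic: same chunk size -> same dask name
    >>> da.chunk(3).data.name == da.chunk(3).data.name
    True
    >>> # chunks now feed into the token: different chunk size -> new name
    >>> da.chunk(3).data.name == da.chunk(2).data.name
    False
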
Co-Authored-By: crusaderky Co-authored-by: crusaderky --- doc/whats-new.rst | 4 ++++ xarray/core/dataset.py | 5 ++++- xarray/core/utils.py | 7 ++++++- xarray/tests/test_dask.py | 18 +++++++++--------- xarray/tests/test_dataarray.py | 7 +++++++ xarray/tests/test_dataset.py | 16 ++++++++++++++++ 6 files changed, 46 insertions(+), 11 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5a9f2497ed6..e1c4e3dd9ac 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -47,6 +47,7 @@ New Features Bug fixes ~~~~~~~~~ + - Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger `_. @@ -91,6 +92,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Make sure dask names change when rechunking by different chunk sizes. Conversely, make sure they + stay the same when rechunking by the same chunk size. (:issue:`3350`) + By `Deepak Cherian `_. - 2x to 5x speed boost (on small arrays) for :py:meth:`Dataset.isel`, :py:meth:`DataArray.isel`, and :py:meth:`DataArray.__getitem__` when indexing by int, slice, list of int, scalar ndarray, or 1-dimensional ndarray. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ac0a923db78..129f0c0f7a7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1754,7 +1754,10 @@ def maybe_chunk(name, var, chunks): if not chunks: chunks = None if var.ndim > 0: - token2 = tokenize(name, token if token else var._data) + # when rechunking by different amounts, make sure dask names change + # by provinding chunks as an input to tokenize. + # subtle bugs result otherwise. see GH3350 + token2 = tokenize(name, token if token else var._data, chunks) name2 = f"{name_prefix}{name}-{token2}" return var.chunk(chunks, name=name2, lock=lock) else: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 6681375c18e..e335365d5ca 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -547,7 +547,12 @@ def __eq__(self, other) -> bool: return False def __hash__(self) -> int: - return hash((ReprObject, self._value)) + return hash((type(self), self._value)) + + def __dask_tokenize__(self): + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) @contextlib.contextmanager diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index d0e2654eed3..cc554850839 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1083,7 +1083,7 @@ def func(obj): actual = xr.map_blocks(func, obj) expected = func(obj) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1092,7 +1092,7 @@ def test_map_blocks_convert_args_to_list(obj): with raise_if_dask_computes(): actual = xr.map_blocks(operator.add, obj, [10]) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1107,7 +1107,7 @@ def add_attrs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(add_attrs, obj) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) def test_map_blocks_change_name(map_da): @@ -1120,7 +1120,7 @@ def change_name(obj): with raise_if_dask_computes(): actual = xr.map_blocks(change_name, map_da) - xr.testing.assert_identical(actual.compute(), expected.compute()) + 
assert_identical(actual, expected) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1129,7 +1129,7 @@ def test_map_blocks_kwargs(obj): with raise_if_dask_computes(): actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan)) assert_chunks_equal(expected.chunk(), actual) - xr.testing.assert_identical(actual.compute(), expected.compute()) + assert_identical(actual, expected) def test_map_blocks_to_array(map_ds): @@ -1137,7 +1137,7 @@ def test_map_blocks_to_array(map_ds): actual = xr.map_blocks(lambda x: x.to_array(), map_ds) # to_array does not preserve name, so cannot use assert_identical - assert_equal(actual.compute(), map_ds.to_array().compute()) + assert_equal(actual, map_ds.to_array()) @pytest.mark.parametrize( @@ -1156,7 +1156,7 @@ def test_map_blocks_da_transformations(func, map_da): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_da) - assert_identical(actual.compute(), func(map_da).compute()) + assert_identical(actual, func(map_da)) @pytest.mark.parametrize( @@ -1175,7 +1175,7 @@ def test_map_blocks_ds_transformations(func, map_ds): with raise_if_dask_computes(): actual = xr.map_blocks(func, map_ds) - assert_identical(actual.compute(), func(map_ds).compute()) + assert_identical(actual, func(map_ds)) @pytest.mark.parametrize("obj", [make_da(), make_ds()]) @@ -1188,7 +1188,7 @@ def func(obj): expected = xr.map_blocks(func, obj) actual = obj.map_blocks(func) - assert_identical(expected.compute(), actual.compute()) + assert_identical(expected, actual) def test_map_blocks_hlg_layers(): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4189c3b504a..786eb5007a6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -752,12 +752,19 @@ def test_chunk(self): blocked = unblocked.chunk() assert blocked.chunks == ((3,), (4,)) + first_dask_name = blocked.data.name blocked = unblocked.chunk(chunks=((2, 1), (2, 2))) assert blocked.chunks == ((2, 1), (2, 2)) + assert blocked.data.name != first_dask_name blocked = unblocked.chunk(chunks=(3, 3)) assert blocked.chunks == ((3,), (3, 1)) + assert blocked.data.name != first_dask_name + + # name doesn't change when rechunking by same amount + # this fails if ReprObject doesn't have __dask_tokenize__ defined + assert unblocked.chunk(2).data.name == unblocked.chunk(2).data.name assert blocked.load().chunks is None diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2220abbef31..c953f5d22e9 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -936,19 +936,35 @@ def test_chunk(self): expected_chunks = {"dim1": (8,), "dim2": (9,), "dim3": (10,)} assert reblocked.chunks == expected_chunks + def get_dask_names(ds): + return {k: v.data.name for k, v in ds.items()} + + orig_dask_names = get_dask_names(reblocked) + reblocked = data.chunk({"time": 5, "dim1": 5, "dim2": 5, "dim3": 5}) # time is not a dim in any of the data_vars, so it # doesn't get chunked expected_chunks = {"dim1": (5, 3), "dim2": (5, 4), "dim3": (5, 5)} assert reblocked.chunks == expected_chunks + # make sure dask names change when rechunking by different amounts + # regression test for GH3350 + new_dask_names = get_dask_names(reblocked) + for k, v in new_dask_names.items(): + assert v != orig_dask_names[k] + reblocked = data.chunk(expected_chunks) assert reblocked.chunks == expected_chunks # reblock on already blocked data + orig_dask_names = get_dask_names(reblocked) reblocked = reblocked.chunk(expected_chunks) + new_dask_names = 
get_dask_names(reblocked) assert reblocked.chunks == expected_chunks assert_identical(reblocked, data) + # recuhnking with same chunk sizes should not change names + for k, v in new_dask_names.items(): + assert v == orig_dask_names[k] with raises_regex(ValueError, "some chunks"): data.chunk({"foo": 10}) From e8dbe6ee8cf3d94af0a35e92f7d77b08dc4b95df Mon Sep 17 00:00:00 2001 From: Riley Brady Date: Fri, 10 Jan 2020 09:48:31 -0700 Subject: [PATCH 06/29] Add map_blocks example to whats-new (#3682) --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e1c4e3dd9ac..93df2b6569a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -89,6 +89,8 @@ Documentation - Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and ``GroupBy.quantile``. (:pull:`3576`) By `Justus Magin `_. +- Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`) + By `Riley X. Brady `_. Internal Changes ~~~~~~~~~~~~~~~~ From ff75081304eb2e2784dcb229cc48a532da557896 Mon Sep 17 00:00:00 2001 From: rpgoldman Date: Fri, 10 Jan 2020 14:02:22 -0600 Subject: [PATCH 07/29] How do I add a new variable to dataset. (#3679) * How do I add a new variable to dataset. * Update doc/howdoi.rst Improvement from dcherian Co-Authored-By: Deepak Cherian * Add cross-reference per suggestion. * Fix cross-reference. * Update doc/howdoi.rst Suggestion from keewis Co-Authored-By: keewis Co-authored-by: Deepak Cherian Co-authored-by: keewis --- doc/data-structures.rst | 2 ++ doc/howdoi.rst | 3 +++ 2 files changed, 5 insertions(+) diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 504d820a234..70e34adabed 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -353,6 +353,8 @@ setting) variables and attributes: This is particularly useful in an exploratory context, because you can tab-complete these variable names with tools like IPython. +.. _dictionary_like_methods: + Dictionary like methods ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/howdoi.rst b/doc/howdoi.rst index 80266bd3b84..84c0c786027 100644 --- a/doc/howdoi.rst +++ b/doc/howdoi.rst @@ -11,6 +11,8 @@ How do I ... * - How do I... - Solution + * - add a DataArray to my dataset as a new variable + - ``my_dataset[varname] = my_dataArray`` or :py:meth:`Dataset.assign` (see also :ref:`dictionary_like_methods`) * - add variables from other datasets to my dataset - :py:meth:`Dataset.merge` * - add a new dimension and/or coordinate @@ -57,3 +59,4 @@ How do I ... - ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more. * - make a mask that is ``True`` where an object contains any of the values in a array - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin` + From 099c0901e29927935440692b1363ef08fe6d2e42 Mon Sep 17 00:00:00 2001 From: Julien Seguinot Date: Sat, 11 Jan 2020 16:22:55 +0100 Subject: [PATCH 08/29] Add option to choose mfdataset attributes source. (#3498) * Add 'master_file' kwarg in open_mfdataset, which can be a str or Path to a particular data file. 
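* Note: the keyword ended up as `attrs_file` in the final diff below. A usage
  sketch (file names here are hypothetical):

      >>> ds = xr.open_mfdataset(
      ...     ["day1.nc", "day2.nc"],
      ...     concat_dim="time",
      ...     combine="nested",
      ...     attrs_file="day2.nc",
      ... )  # global attributes are taken from day2.nc, not the first file
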
--- doc/whats-new.rst | 3 +++ xarray/backends/api.py | 19 +++++++++++++++--- xarray/tests/test_backends.py | 36 +++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 93df2b6569a..eacf8433c0a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,9 @@ New Features - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ +- Add `attrs_file` option in :py:func:`~xarray.open_mfdataset` to choose the + source file for global attributes in a multi-file dataset (:issue:`2382`, + :pull:`3498`) by `Julien Seguinot _`. - :py:meth:`Dataset.swap_dims` and :py:meth:`DataArray.swap_dims` now allow swapping to dimension names that don't exist yet. (:pull:`3636`) By `Justus Magin `_. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 23d09ba5e33..eea1fb9ddce 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -718,6 +718,7 @@ def open_mfdataset( autoclose=None, parallel=False, join="outer", + attrs_file=None, **kwargs, ): """Open multiple files as a single dataset. @@ -729,8 +730,8 @@ def open_mfdataset( ``combine_by_coords`` and ``combine_nested``. By default the old (now deprecated) ``auto_combine`` will be used, please specify either ``combine='by_coords'`` or ``combine='nested'`` in future. Requires dask to be installed. See documentation for - details on dask [1]_. Attributes from the first dataset file are used for the - combined dataset. + details on dask [1]_. Global attributes from the ``attrs_file`` are used + for the combined dataset. Parameters ---------- @@ -827,6 +828,10 @@ def open_mfdataset( - 'override': if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. + attrs_file : str or pathlib.Path, optional + Path of the file used to read global attributes from. + By default global attributes are read from the first file provided, + with wildcard matches sorted by filename. **kwargs : optional Additional arguments passed on to :py:func:`xarray.open_dataset`. 
@@ -961,7 +966,15 @@ def open_mfdataset( raise combined._file_obj = _MultiFileCloser(file_objs) - combined.attrs = datasets[0].attrs + + # read global attributes from the attrs_file or from the first dataset + if attrs_file is not None: + if isinstance(attrs_file, Path): + attrs_file = str(attrs_file) + combined.attrs = datasets[paths.index(attrs_file)].attrs + else: + combined.attrs = datasets[0].attrs + return combined diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a23527bd49a..4fccdf2dd6c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2832,6 +2832,42 @@ def test_attrs_mfdataset(self): with raises_regex(AttributeError, "no attribute"): actual.test2 + def test_open_mfdataset_attrs_file(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + + def test_open_mfdataset_attrs_file_path(self): + original = Dataset({"foo": ("x", np.random.randn(10))}) + with create_tmp_files(2) as (tmp1, tmp2): + tmp1 = Path(tmp1) + tmp2 = Path(tmp2) + ds1 = original.isel(x=slice(5)) + ds2 = original.isel(x=slice(5, 10)) + ds1.attrs["test1"] = "foo" + ds2.attrs["test2"] = "bar" + ds1.to_netcdf(tmp1) + ds2.to_netcdf(tmp2) + with open_mfdataset( + [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2 + ) as actual: + # attributes are inherited from the master file + assert actual.attrs["test2"] == ds2.attrs["test2"] + # attributes from ds1 are not retained, e.g., + assert "test1" not in actual.attrs + def test_open_mfdataset_auto_combine(self): original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) with create_tmp_file() as tmp1: From 40fab1b303d40c37d7307342cebbceb1ea8cc608 Mon Sep 17 00:00:00 2001 From: David Caron Date: Sat, 11 Jan 2020 18:15:44 -0500 Subject: [PATCH 09/29] fix docstring for combine_first: returns a Dataset (#3683) --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 129f0c0f7a7..c314eb41458 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4152,7 +4152,7 @@ def combine_first(self, other: "Dataset") -> "Dataset": Returns ------- - DataArray + Dataset """ out = ops.fillna(self, other, join="outer", dataset_join="outer") return out From 1689db493f10262555196f658c52e370aacb4a33 Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Sun, 12 Jan 2020 13:04:01 +0000 Subject: [PATCH 10/29] ds.merge(da) bugfix (#3677) * Added mwe as test * Cast to Dataset * Updated what's new * black formatted * Use assert_identical --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 1 + xarray/tests/test_merge.py | 7 +++++++ 3 files changed, 10 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index eacf8433c0a..6eeb55c23cb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -74,6 +74,8 @@ Bug fixes By `Justus Magin `_. - :py:meth:`Dataset.rename`, :py:meth:`DataArray.rename` now check for conflicts with MultiIndex level names. 
+- :py:meth:`Dataset.merge` no longer fails when passed a `DataArray` instead of a `Dataset` object. + By `Tom Nicholas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c314eb41458..6ecd9b59f8e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3607,6 +3607,7 @@ def merge( If any variables conflict (see ``compat``). """ _check_inplace(inplace) + other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( self, other, diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c1e6c7a5ce8..6c8f3f65657 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -3,6 +3,7 @@ import xarray as xr from xarray.core import dtypes, merge +from xarray.testing import assert_identical from . import raises_regex from .test_dataset import create_test_data @@ -253,3 +254,9 @@ def test_merge_no_conflicts(self): with pytest.raises(xr.MergeError): ds3 = xr.Dataset({"a": ("y", [2, 3]), "y": [1, 2]}) ds1.merge(ds3, compat="no_conflicts") + + def test_merge_dataarray(self): + ds = xr.Dataset({"a": 0}) + da = xr.DataArray(data=1, name="b") + + assert_identical(ds.merge(da), xr.merge([ds, da])) From 59d3ba5e938bafb4a1981c1a56d42aa31041df0a Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Mon, 13 Jan 2020 11:31:37 -0500 Subject: [PATCH 11/29] Explicitly convert result of pd.to_datetime to a timezone-naive type (#3688) --- xarray/tests/test_coding_times.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index d012fb36c35..00c34940ce4 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -451,7 +451,7 @@ def test_cf_datetime_nan(num_dates, units, expected_list): warnings.filterwarnings("ignore", "All-NaN") actual = coding.times.decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions - expected = pd.to_datetime(expected_list) + expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") assert_array_equal(expected, actual) From 8a650a11d1f859a88cc91b8815c16597203892aa Mon Sep 17 00:00:00 2001 From: Tom Nicholas <35968931+TomNicholas@users.noreply.github.com> Date: Mon, 13 Jan 2020 16:33:04 +0000 Subject: [PATCH 12/29] Fix mypy type checking tests failure in ds.merge (#3690) * Added DataArray as valid type input to ds.merge method * Fix import error by specifying type as string --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6ecd9b59f8e..2d15ff586e4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -3550,7 +3550,7 @@ def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": def merge( self, - other: "CoercibleMapping", + other: Union["CoercibleMapping", "DataArray"], inplace: bool = None, overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(), compat: str = "no_conflicts", From 40423457928245d216f973530904df1c93110f6c Mon Sep 17 00:00:00 2001 From: Emmanuel Roux <15956441+fesaille@users.noreply.github.com> Date: Mon, 13 Jan 2020 21:32:17 +0100 Subject: [PATCH 13/29] Typo on DataSet/DataArray.to_dict documentation (#3692) --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cbd8d243385..15db6ed468e 100644 
--- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2368,7 +2368,7 @@ def to_dict(self, data: bool = True) -> dict: naming conventions. Converts all variables and attributes to native Python objects. - Useful for coverting to json. To avoid datetime incompatibility + Useful for converting to json. To avoid datetime incompatibility use decode_times=False kwarg in xarrray.open_dataset. Parameters diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2d15ff586e4..1aaad02b470 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4667,7 +4667,7 @@ def to_dict(self, data=True): conventions. Converts all variables and attributes to native Python objects - Useful for coverting to json. To avoid datetime incompatibility + Useful for converting to json. To avoid datetime incompatibility use decode_times=False kwarg in xarrray.open_dataset. Parameters From e0fd48052dbda34ee35d2491e4fe856495c9621b Mon Sep 17 00:00:00 2001 From: keewis Date: Tue, 14 Jan 2020 17:13:23 +0100 Subject: [PATCH 14/29] allow passing any iterable to drop when dropping variables (#3693) * allow passing any iterable to drop when dropping variables * whats-new.rst * update whats-new.rst --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 3 +-- xarray/tests/test_dataset.py | 4 ++++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6eeb55c23cb..1ad344a208d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -76,6 +76,9 @@ Bug fixes MultiIndex level names. - :py:meth:`Dataset.merge` no longer fails when passed a `DataArray` instead of a `Dataset` object. By `Tom Nicholas `_. +- Fix a regression in :py:meth:`Dataset.drop`: allow passing any + iterable when dropping variables (:issue:`3552`, :pull:`3693`) + By `Justus Magin `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1aaad02b470..79f1030fabe 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -85,7 +85,6 @@ either_dict_or_kwargs, hashable, is_dict_like, - is_list_like, is_scalar, maybe_wrap_array, ) @@ -3690,7 +3689,7 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): raise ValueError("cannot specify dim and dict-like arguments.") labels = either_dict_or_kwargs(labels, labels_kwargs, "drop") - if dim is None and (is_list_like(labels) or is_scalar(labels)): + if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)): warnings.warn( "dropping variables using `drop` will be deprecated; using drop_vars is encouraged.", PendingDeprecationWarning, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c953f5d22e9..f9eb37dbf2f 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2167,6 +2167,10 @@ def test_drop_variables(self): actual = data.drop(["time", "not_found_here"], errors="ignore") assert_identical(expected, actual) + with pytest.warns(PendingDeprecationWarning): + actual = data.drop({"time", "not_found_here"}, errors="ignore") + assert_identical(expected, actual) + def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) From 99594051ef591f12b4b78a8b24136da46d0bf28f Mon Sep 17 00:00:00 2001 From: Spencer Clark Date: Wed, 15 Jan 2020 10:22:29 -0500 Subject: [PATCH 15/29] Use encoding['dtype'] over data.dtype when possible within CFMaskCoder.encode (#3652) * Use encoding['dtype'] over data.dtype when possible * Add what's new entry * Fix typo in what's new Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 3 +++ xarray/coding/variables.py | 5 +++-- xarray/tests/test_coding.py | 36 +++++++++++++++++++++++++++--------- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1ad344a208d..f8d13098250 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -69,6 +69,9 @@ Bug fixes By `Tom Augspurger `_. - Ensure :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` issue the correct error when ``q`` is out of bounds (:issue:`3634`) by `Mathias Hauser `_. +- Fix regression in xarray 0.14.1 that prevented encoding times with certain + ``dtype``, ``_FillValue``, and ``missing_value`` encodings (:issue:`3624`). + By `Spencer Clark `_ - Raise an error when trying to use :py:meth:`Dataset.rename_dims` to rename to an existing name (:issue:`3438`, :pull:`3645`) By `Justus Magin `_. 
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 2b5f87ab0cd..28ead397461 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -148,6 +148,7 @@ class CFMaskCoder(VariableCoder): def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) + dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") @@ -162,14 +163,14 @@ def encode(self, variable, name=None): if fv is not None: # Ensure _FillValue is cast to same dtype as data's - encoding["_FillValue"] = data.dtype.type(fv) + encoding["_FillValue"] = dtype.type(fv) fill_value = pop_to(encoding, attrs, "_FillValue", name=name) if not pd.isnull(fill_value): data = duck_array_ops.fillna(data, fill_value) if mv is not None: # Ensure missing_value is cast to same dtype as data's - encoding["missing_value"] = data.dtype.type(mv) + encoding["missing_value"] = dtype.type(mv) fill_value = pop_to(encoding, attrs, "missing_value", name=name) if not pd.isnull(fill_value) and fv is None: data = duck_array_ops.fillna(data, fill_value) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 3e0474e7b60..0f191049284 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,10 +1,12 @@ from contextlib import suppress import numpy as np +import pandas as pd import pytest import xarray as xr from xarray.coding import variables +from xarray.conventions import decode_cf_variable, encode_cf_variable from . import assert_equal, assert_identical, requires_dask @@ -20,20 +22,36 @@ def test_CFMaskCoder_decode(): assert_identical(expected, encoded) -def test_CFMaskCoder_encode_missing_fill_values_conflict(): - original = xr.Variable( - ("x",), - [0.0, -1.0, 1.0], - encoding={"_FillValue": np.float32(1e20), "missing_value": np.float64(1e20)}, - ) - coder = variables.CFMaskCoder() - encoded = coder.encode(original) +encoding_with_dtype = { + "dtype": np.dtype("float64"), + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +encoding_without_dtype = { + "_FillValue": np.float32(1e20), + "missing_value": np.float64(1e20), +} +CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS = { + "numeric-with-dtype": ([0.0, -1.0, 1.0], encoding_with_dtype), + "numeric-without-dtype": ([0.0, -1.0, 1.0], encoding_without_dtype), + "times-with-dtype": (pd.date_range("2000", periods=3), encoding_with_dtype), +} + + +@pytest.mark.parametrize( + ("data", "encoding"), + CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.values(), + ids=list(CFMASKCODER_ENCODE_DTYPE_CONFLICT_TESTS.keys()), +) +def test_CFMaskCoder_encode_missing_fill_values_conflict(data, encoding): + original = xr.Variable(("x",), data, encoding=encoding) + encoded = encode_cf_variable(original) assert encoded.dtype == encoded.attrs["missing_value"].dtype assert encoded.dtype == encoded.attrs["_FillValue"].dtype with pytest.warns(variables.SerializationWarning): - roundtripped = coder.decode(coder.encode(original)) + roundtripped = decode_cf_variable("foo", encoded) assert_identical(roundtripped, original) From f8386bed945ec5e586baf906548903da93149258 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 15 Jan 2020 08:25:56 -0700 Subject: [PATCH 16/29] Add an example notebook using apply_ufunc to vectorize 1D functions (#3629) * apply_ufunc vectorize 1D function example * Small udpates. * Review + strip output. * Minor updates. 
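The notebook builds up to the following call pattern (a condensed sketch;
`interp1d_np`, `air` and `newlat` are defined in the notebook itself):

    >>> interped = xr.apply_ufunc(
    ...     interp1d_np,  # wraps np.interp over 1D vectors along "lat"
    ...     air,
    ...     air.lat,
    ...     newlat,
    ...     input_core_dims=[["lat"], ["lat"], ["new_lat"]],
    ...     output_core_dims=[["new_lat"]],
    ...     exclude_dims=set(("lat",)),  # "lat" is allowed to change size
    ...     vectorize=True,  # loop over non-core dims
    ...     dask="parallelized",
    ...     output_dtypes=[air.dtype],
    ... )
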
--- ci/requirements/doc.yml | 1 + doc/examples.rst | 7 + doc/examples/apply_ufunc_vectorize_1d.ipynb | 736 ++++++++++++++++++++ doc/whats-new.rst | 3 + 4 files changed, 747 insertions(+) create mode 100644 doc/examples/apply_ufunc_vectorize_1d.ipynb diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index a0c27a30b01..1b479ad4ca2 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -14,6 +14,7 @@ dependencies: - jupyter_client - nbsphinx - netcdf4 + - numba - numpy - numpydoc - pandas<0.25 # Hack around https://github.com/pydata/xarray/issues/3369 diff --git a/doc/examples.rst b/doc/examples.rst index ce56102cc9d..3067ca824be 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -10,3 +10,10 @@ Examples examples/visualization_gallery examples/ROMS_ocean_model examples/ERA5-GRIB-example + +Using apply_ufunc +------------------ +.. toctree:: + :maxdepth: 2 + + examples/apply_ufunc_vectorize_1d diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb new file mode 100644 index 00000000000..6d18d48fdb5 --- /dev/null +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -0,0 +1,736 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Applying unvectorized functions with `apply_ufunc`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example will illustrate how to conveniently apply an unvectorized function `func` to xarray objects using `apply_ufunc`. `func` expects 1D numpy arrays and returns a 1D numpy array. Our goal is to coveniently apply this function along a dimension of xarray objects that may or may not wrap dask arrays with a signature.\n", + "\n", + "We will illustrate this using `np.interp`: \n", + "\n", + " Signature: np.interp(x, xp, fp, left=None, right=None, period=None)\n", + " Docstring:\n", + " One-dimensional linear interpolation.\n", + "\n", + " Returns the one-dimensional piecewise linear interpolant to a function\n", + " with given discrete data points (`xp`, `fp`), evaluated at `x`.\n", + "\n", + "and write an `xr_interp` function with signature\n", + "\n", + " xr_interp(xarray_object, dimension_name, new_coordinate_to_interpolate_to)\n", + "\n", + "### Load data\n", + "\n", + "First lets load an example dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:51.659160Z", + "start_time": "2020-01-15T14:45:50.528742Z" + } + }, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "\n", + "xr.set_options(display_style=\"html\") # fancy HTML repr\n", + "\n", + "air = (\n", + " xr.tutorial.load_dataset(\"air_temperature\")\n", + " .air.sortby(\"lat\") # np.interp needs coordinate in ascending order\n", + " .isel(time=slice(4), lon=slice(3))\n", + ") # choose a small subset for convenience\n", + "air" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function we will apply is `np.interp` which expects 1D numpy arrays. This functionality is already implemented in xarray so we use that capability to make sure we are not making mistakes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:55.431708Z", + "start_time": "2020-01-15T14:45:55.104701Z" + } + }, + "outputs": [], + "source": [ + "newlat = np.linspace(15, 75, 100)\n", + "air.interp(lat=newlat)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's define a function that works with one vector of data along `lat` at a time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:57.889496Z", + "start_time": "2020-01-15T14:45:57.792269Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = interp1d_np(air.isel(time=0, lon=0), air.lat, newlat)\n", + "expected = air.interp(lat=newlat)\n", + "\n", + "# no errors are raised if values are equal to within floating point precision\n", + "np.testing.assert_allclose(expected.isel(time=0, lon=0).values, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### No errors are raised so our interpolation is working.\n", + "\n", + "This function consumes and returns numpy arrays, which means we need to do a lot of work to convert the result back to an xarray object with meaningful metadata. This is where `apply_ufunc` is very useful.\n", + "\n", + "### `apply_ufunc`\n", + "\n", + " Apply a vectorized function for unlabeled arrays on xarray objects.\n", + "\n", + " The function will be mapped over the data variable(s) of the input arguments using \n", + " xarray’s standard rules for labeled computation, including alignment, broadcasting, \n", + " looping over GroupBy/Dataset variables, and merging of coordinates.\n", + " \n", + "`apply_ufunc` has many capabilities but for simplicity this example will focus on the common task of vectorizing 1D functions over nD xarray objects. We will iteratively build up the right set of arguments to `apply_ufunc` and read through many error messages in doing so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:45:59.768626Z", + "start_time": "2020-01-15T14:45:59.543808Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`apply_ufunc` needs to know a lot of information about what our function does so that it can reconstruct the outputs. In this case, the size of dimension lat has changed and we need to explicitly specify that this will happen. xarray helpfully tells us that we need to specify the kwarg `exclude_dims`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### `exclude_dims`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "```\n", + "exclude_dims : set, optional\n", + " Core dimensions on the inputs to exclude from alignment and\n", + " broadcasting entirely. Any input coordinates along these dimensions\n", + " will be dropped. Each excluded dimension must also appear in\n", + " ``input_core_dims`` for at least one argument. 
Only dimensions listed\n", + " here are allowed to change size between input and output objects.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:02.187012Z", + "start_time": "2020-01-15T14:46:02.105563Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Core dimensions\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Core dimensions are central to using `apply_ufunc`. In our case, our function expects to receive a 1D vector along `lat` — this is the dimension that is \"core\" to the function's functionality. Multiple core dimensions are possible. `apply_ufunc` needs to know which dimensions of each variable are core dimensions.\n", + "\n", + " input_core_dims : Sequence[Sequence], optional\n", + " List of the same length as ``args`` giving the list of core dimensions\n", + " on each input argument that should not be broadcast. By default, we\n", + " assume there are no core dimensions on any input arguments.\n", + "\n", + " For example, ``input_core_dims=[[], ['time']]`` indicates that all\n", + " dimensions on the first argument and all dimensions other than 'time'\n", + " on the second argument should be broadcast.\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_.\n", + " \n", + " output_core_dims : List[tuple], optional\n", + " List of the same length as the number of output arguments from\n", + " ``func``, giving the list of core dimensions on each output that were\n", + " not broadcast on the inputs. By default, we assume that ``func``\n", + " outputs exactly one array, with axes corresponding to each broadcast\n", + " dimension.\n", + "\n", + " Core dimensions are assumed to appear as the last dimensions of each\n", + " output in the provided order.\n", + " \n", + "Next we specify `\"lat\"` as `input_core_dims` on both `air` and `air.lat`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:05.031672Z", + "start_time": "2020-01-15T14:46:04.947588Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "xarray is telling us that it expected to receive back a numpy array with 0 dimensions but instead received an array with 1 dimension corresponding to `newlat`. 
We can fix this by specifying `output_core_dims`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:09.325218Z", + "start_time": "2020-01-15T14:46:09.303020Z" + } + }, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we get some output! Let's check that this is right\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:11.295440Z", + "start_time": "2020-01-15T14:46:11.226553Z" + } + }, + "outputs": [], + "source": [ + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(time=0, lon=0), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "No errors are raised so it is right!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorization with `np.vectorize`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", + "Let's try passing the whole dataset. We add a `print` statement so we can see what our function receives." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:13.808646Z", + "start_time": "2020-01-15T14:46:13.680098Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.isel(\n", + " lon=slice(3), time=slice(4)\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat,\n", + " newlat,\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]],\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected.isel(time=0, lon=0), interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's a hard-to-interpret error but our `print` call helpfully printed the shapes of the input data: \n", + "\n", + " data: (10, 53, 25) | x: (25,) | xi: (100,)\n", + "\n", + "We see that `air` has been passed as a 3D numpy array which is not what `np.interp` expects. 
Instead we want loop over all combinations of `lon` and `time`; and apply our function to each corresponding vector of data along `lat`.\n", + "`apply_ufunc` makes this easy by specifying `vectorize=True`:\n", + "\n", + " vectorize : bool, optional\n", + " If True, then assume ``func`` only takes arrays defined over core\n", + " dimensions as input and vectorize it automatically with\n", + " :py:func:`numpy.vectorize`. This option exists for convenience, but is\n", + " almost always slower than supplying a pre-vectorized function.\n", + " Using this option requires NumPy version 1.12 or newer.\n", + " \n", + "Also see the documentation for `np.vectorize`: https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html. Most importantly\n", + "\n", + " The vectorize function is provided primarily for convenience, not for performance. \n", + " The implementation is essentially a for loop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:26.633233Z", + "start_time": "2020-01-15T14:46:26.515209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], []], # list with one entry per arg\n", + " output_core_dims=[[\"lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. Must be set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(expected, interped)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This unfortunately is another cryptic error from numpy. \n", + "\n", + "Notice that `newlat` is not an xarray object. Let's add a dimension name `new_lat` and modify the call. Note this cannot be `lat` because xarray expects dimensions to be the same size (or broadcastable) among all inputs. `output_core_dims` needs to be modified appropriately. We'll manually rename `new_lat` back to `lat` for easy checking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:46:30.026663Z", + "start_time": "2020-01-15T14:46:29.893267Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air, # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n", + " vectorize=True, # loop over non-core dims\n", + ")\n", + "interped = interped.rename({\"new_lat\": \"lat\"})\n", + "interped[\"lat\"] = newlat # need to add this manually\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims), interped # order of dims is different\n", + ")\n", + "interped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the printed input shapes are all 1D and correspond to one vector along the `lat` dimension.\n", + "\n", + "The result is now an xarray object with coordinate values copied over from `data`. This is why `apply_ufunc` is so convenient; it takes care of a lot of boilerplate necessary to apply functions that consume and produce numpy arrays to xarray objects.\n", + "\n", + "One final point: `lat` is now the *last* dimension in `interped`. This is a \"property\" of core dimensions: they are moved to the end before being sent to `interp1d_np` as was noted in the docstring for `input_core_dims`\n", + "\n", + " Core dimensions are automatically moved to the last axes of input\n", + " variables before applying ``func``, which facilitates using NumPy style\n", + " generalized ufuncs [2]_." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Parallelization with dask\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So far our function can only handle numpy arrays. A real benefit of `apply_ufunc` is the ability to easily parallelize over dask chunks _when needed_. \n", + "\n", + "We want to apply this function in a vectorized fashion over each chunk of the dask array. This is possible using dask's `blockwise` or `map_blocks`. `apply_ufunc` wraps `blockwise` and asking it to map the function over chunks using `blockwise` is as simple as specifying `dask=\"parallelized\"`. With this level of flexibility we need to provide dask with some extra information: \n", + " 1. `output_dtypes`: dtypes of all returned objects, and \n", + " 2. `output_sizes`: lengths of any new dimensions. \n", + " \n", + "Here we need to specify `output_dtypes` since `apply_ufunc` can infer the size of the new dimension `new_lat` from the argument corresponding to the third element in `input_core_dims`. Here I choose the chunk sizes to illustrate that `np.vectorize` is still applied so that our function receives 1D vectors even though the blocks are 3D." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:48:42.469341Z", + "start_time": "2020-01-15T14:48:42.344209Z" + } + }, + "outputs": [], + "source": [ + "def interp1d_np(data, x, xi):\n", + " print(f\"data: {data.shape} | x: {x.shape} | xi: {xi.shape}\")\n", + " return np.interp(xi, x, data)\n", + "\n", + "\n", + "interped = xr.apply_ufunc(\n", + " interp1d_np, # first the function\n", + " air.chunk(\n", + " {\"time\": 2, \"lon\": 2}\n", + " ), # now arguments in the order expected by 'interp1_np'\n", + " air.lat, # as above\n", + " newlat, # as above\n", + " input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]], # list with one entry per arg\n", + " output_core_dims=[[\"new_lat\"]], # returned data has one dimension\n", + " exclude_dims=set((\"lat\",)), # dimensions allowed to change size. 
Must be a set!\n",
+    "    vectorize=True,  # loop over non-core dims\n",
+    "    dask=\"parallelized\",\n",
+    "    output_dtypes=[air.dtype],  # one per output\n",
+    ").rename({\"new_lat\": \"lat\"})\n",
+    "interped[\"lat\"] = newlat  # need to add this manually\n",
+    "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Yay! Our function is receiving 1D vectors, so we've successfully parallelized applying a 1D function over a block. If you have a distributed dashboard up, you should see computes happening as equality is checked.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### High performance vectorization: gufuncs, numba & guvectorize\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`np.vectorize` is a very convenient function but is unfortunately slow. It is only marginally faster than writing a for loop in Python. A common way to get around this is to write a base interpolation function that can handle nD arrays in a compiled language like Fortran and then pass that to `apply_ufunc`.\n",
+    "\n",
+    "Another option is to use the numba package, which provides a very convenient `guvectorize` decorator: https://numba.pydata.org/numba-doc/latest/user/vectorize.html#the-guvectorize-decorator\n",
+    "\n",
+    "Any decorated function gets compiled and will loop over any non-core dimension in parallel when necessary. We need to specify some extra information:\n",
+    "\n",
+    "   1. Our function cannot return a variable any more. Instead it must receive a variable (the last argument) whose contents the function will modify. So we change from `def interp1d_np(data, x, xi)` to `def interp1d_np_gufunc(data, x, xi, out)`. Our computed results must be assigned to `out`. All values of `out` must be assigned explicitly.\n",
+    "   \n",
+    "   2. `guvectorize` needs to know the dtypes of the inputs and output. These are specified in string form as the first argument. Each element corresponds to one argument of the function. In this case, we specify `float64` for all inputs and outputs: `\"(float64[:], float64[:], float64[:], float64[:])\"` corresponding to `data, x, xi, out`\n",
+    "   \n",
+    "   3. Now we need to tell numba the sizes of the dimensions the function takes as inputs and returns as output, i.e. the core dimensions. This is done in symbolic form, i.e. `data` and `x` are vectors of the same length, say `n`; `xi` and the output `out` have a different length, say `m`. 
So the second argument is (again as a string)\n",
+    "   `\"(n), (n), (m) -> (m)\"`, corresponding again to `data, x, xi, out`\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-01-15T14:48:45.267633Z",
+     "start_time": "2020-01-15T14:48:44.943939Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from numba import float64, guvectorize\n",
+    "\n",
+    "\n",
+    "@guvectorize(\"(float64[:], float64[:], float64[:], float64[:])\", \"(n), (n), (m) -> (m)\")\n",
+    "def interp1d_np_gufunc(data, x, xi, out):\n",
+    "    # numba doesn't really like this; it doesn't\n",
+    "    # seem to support f-strings, so do it the old way\n",
+    "    print(\n",
+    "        \"data: \" + str(data.shape) + \" | x: \" + str(x.shape) + \" | xi: \" + str(xi.shape)\n",
+    "    )\n",
+    "    out[:] = np.interp(xi, x, data)\n",
+    "    # gufuncs don't return data\n",
+    "    # instead you assign to the last arg\n",
+    "    # return np.interp(xi, x, data)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The warnings are about object-mode compilation relating to the `print` statement. This means we don't get much of a speed-up: https://numba.pydata.org/numba-doc/latest/user/performance-tips.html#no-python-mode-vs-object-mode. We'll keep the `print` statement temporarily to make sure that `guvectorize` acts like we want it to."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2020-01-15T14:48:54.755405Z",
+     "start_time": "2020-01-15T14:48:54.634724Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "interped = xr.apply_ufunc(\n",
+    "    interp1d_np_gufunc,  # first the function\n",
+    "    air.chunk(\n",
+    "        {\"time\": 2, \"lon\": 2}\n",
+    "    ),  # now arguments in the order expected by 'interp1d_np_gufunc'\n",
+    "    air.lat,  # as above\n",
+    "    newlat,  # as above\n",
+    "    input_core_dims=[[\"lat\"], [\"lat\"], [\"new_lat\"]],  # list with one entry per arg\n",
+    "    output_core_dims=[[\"new_lat\"]],  # returned data has one dimension\n",
+    "    exclude_dims=set((\"lat\",)),  # dimensions allowed to change size. Must be a set!\n",
+    "    # vectorize=True,  # not needed since numba takes care of vectorizing\n",
+    "    dask=\"parallelized\",\n",
+    "    output_dtypes=[air.dtype],  # one per output\n",
+    ").rename({\"new_lat\": \"lat\"})\n",
+    "interped[\"lat\"] = newlat  # need to add this manually\n",
+    "xr.testing.assert_allclose(expected.transpose(*interped.dims), interped)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Yay! Our function is receiving 1D vectors and is working automatically with dask arrays. 
Finally let's comment out the print line and wrap everything up in a nice reusable function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2020-01-15T14:49:28.667528Z", + "start_time": "2020-01-15T14:49:28.103914Z" + } + }, + "outputs": [], + "source": [ + "from numba import float64, guvectorize\n", + "\n", + "\n", + "@guvectorize(\n", + " \"(float64[:], float64[:], float64[:], float64[:])\",\n", + " \"(n), (n), (m) -> (m)\",\n", + " nopython=True,\n", + ")\n", + "def interp1d_np_gufunc(data, x, xi, out):\n", + " out[:] = np.interp(xi, x, data)\n", + "\n", + "\n", + "def xr_interp(data, dim, newdim):\n", + "\n", + " interped = xr.apply_ufunc(\n", + " interp1d_np_gufunc, # first the function\n", + " data, # now arguments in the order expected by 'interp1_np'\n", + " data[dim], # as above\n", + " newdim, # as above\n", + " input_core_dims=[[dim], [dim], [\"__newdim__\"]], # list with one entry per arg\n", + " output_core_dims=[[\"__newdim__\"]], # returned data has one dimension\n", + " exclude_dims=set((dim,)), # dimensions allowed to change size. Must be a set!\n", + " # vectorize=True, # not needed since numba takes care of vectorizing\n", + " dask=\"parallelized\",\n", + " output_dtypes=[data.dtype], # one per output; could also be float or np.dtype(\"float64\")\n", + " ).rename({\"__newdim__\": dim})\n", + " interped[dim] = newdim # need to add this manually\n", + "\n", + " return interped\n", + "\n", + "\n", + "xr.testing.assert_allclose(\n", + " expected.transpose(*interped.dims),\n", + " xr_interp(air.chunk({\"time\": 2, \"lon\": 2}), \"lat\", newlat),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This technique is generalizable to any 1D function." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + }, + "nbsphinx": { + "allow_errors": true + }, + "org": null, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": false, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f8d13098250..17c679d27b5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -100,6 +100,9 @@ Documentation - Added examples for :py:meth:`DataArray.quantile`, :py:meth:`Dataset.quantile` and ``GroupBy.quantile``. (:pull:`3576`) By `Justus Magin `_. +- Add new :py:func:`apply_ufunc` example notebook demonstrating vectorization of a 1D + function using dask and numba. + By `Deepak Cherian `_. - Added example for :py:func:`~xarray.map_blocks`. (:pull:`3667`) By `Riley X. Brady `_. 
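
The notebook above ends by noting that the technique generalizes to any 1D function. As a quick illustration of that claim (not part of the patch), here is a minimal sketch that reuses the same ``apply_ufunc`` recipe for a different 1D routine, ``np.gradient``; the helper name ``xr_gradient`` and the tutorial dataset are assumptions made for the example:

```python
import numpy as np
import xarray as xr


def xr_gradient(data, dim):
    # same recipe as xr_interp above; here the core dimension keeps its
    # size, so no exclude_dims or dimension renaming is needed
    return xr.apply_ufunc(
        np.gradient,  # any 1D-in, 1D-out function can be dropped in
        data,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        vectorize=True,  # loop over the non-core dims
        dask="parallelized",
        output_dtypes=[data.dtype],
    )


# hypothetical usage: chunk only non-core dims so dask="parallelized" works
air = xr.tutorial.open_dataset("air_temperature").air
deriv = xr_gradient(air.chunk({"time": 100}), "lon")
```

Because ``lon`` is both an input and an output core dimension and keeps its size, no ``output_sizes`` hint is required here either.
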
From 3955c37959a81d6e83c6253962d374ee2f5019b6 Mon Sep 17 00:00:00 2001 From: keewis Date: Wed, 15 Jan 2020 17:53:01 +0100 Subject: [PATCH 17/29] Tests for variables with units (#3654) * add tests for variable by inheriting from VariableSubclassobjects * make sure the utility functions work with variables * add additional tests for aggregation and numpy methods * don't assume variables have a name attribute * properly merge the used arguments * properly attach to non-xarray objects * add a test function for searchsorted and item * add tests for missing value detection methods * add tests for comparisons * don't try to check missing value behaviour for int arrays * xfail the compatible unit tests * use an additional check since identical is not sufficient right now * check for compatibility by comparing the dimensionality of units * add tests for fillna * add tests for isel * add tests for 1d math * add tests for broadcast_equals * add tests for masking * add tests for squeeze * add tests for coarsen, quantile, roll, round and shift * add tests for searchsorted * add tests for rolling_window * add tests for transpose * add tests for stack and unstack * add tests for set_dims * add tests for concat, copy and no_conflicts * add tests for reduce * add tests for pad_with_fill_value * all and any have been implemented by pint * remove a unnecessary call to np.array * mark the unrelated failures as xfail * remove a debug print * document the need for force_ndarray * make is_compatible a little bit more robust * clean up is_compatible * add docstrings explaining the use of the method and function classes * put the unit comparisons into a function * actually squeeze the dimensions separately --- xarray/tests/test_units.py | 854 ++++++++++++++++++++++++++++++++++++- 1 file changed, 844 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f8a8a259c1f..2cb1550c088 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,12 +7,15 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE +from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError -unit_registry = pint.UnitRegistry() +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays +unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity pytestmark = [ @@ -28,6 +31,25 @@ ] +def is_compatible(unit1, unit2): + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless + + return unit_like.dimensionality + + return dimensionality(unit1) == dimensionality(unit2) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -113,6 +135,10 @@ def extract_units(obj): } units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} elif isinstance(obj, Quantity): vars_units = {None: array_extract_units(obj)} @@ -148,6 +174,9 @@ def strip_units(obj): new_obj = xr.DataArray( name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + data = 
array_strip_units(obj.data)
+        new_obj = obj.copy(data=data)
     elif isinstance(obj, unit_registry.Quantity):
         new_obj = obj.magnitude
     elif isinstance(obj, (list, tuple)):
@@ -159,8 +188,9 @@ def strip_units(obj):
 
 
 def attach_units(obj, units):
-    if not isinstance(obj, (xr.DataArray, xr.Dataset)):
-        return array_attach_units(obj, units.get("data", 1))
+    if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)):
+        units = units.get("data", None) or units.get(None, None) or 1
+        return array_attach_units(obj, units)
 
     if isinstance(obj, xr.Dataset):
         data_vars = {
@@ -172,7 +202,7 @@ def attach_units(obj, units):
         }
 
         new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
-    else:
+    elif isinstance(obj, xr.DataArray):
         # try the array name, "data" and None, then fall back to dimensionless
         data_units = (
             units.get(obj.name, None)
@@ -198,6 +228,11 @@ def attach_units(obj, units):
         new_obj = xr.DataArray(
             name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims
         )
+    else:
+        data_units = units.get("data", None) or units.get(None, None) or 1
+
+        data = array_attach_units(obj.data, data_units)
+        new_obj = obj.copy(data=data)
 
     return new_obj
 
@@ -243,6 +278,10 @@ def convert_units(obj, to):
     return new_obj
 
 
+def assert_units_equal(a, b):
+    assert extract_units(a) == extract_units(b)
+
+
 def assert_equal_with_units(a, b):
     # works like xr.testing.assert_equal, but also explicitly checks units
     # so, it is more like assert_identical
@@ -282,7 +321,27 @@ def dtype(request):
     return request.param
 
 
+def merge_mappings(*mappings):
+    for key in set(mappings[0]).intersection(*mappings[1:]):
+        yield key, tuple(m[key] for m in mappings)
+
+
+def merge_args(default_args, new_args):
+    from itertools import zip_longest
+
+    fill_value = object()
+    return [
+        second if second is not fill_value else first
+        for first, second in zip_longest(default_args, new_args, fillvalue=fill_value)
+    ]
+
+
 class method:
+    """ wrapper class to help with passing methods via parametrize
+
+    This works a bit like using `partial(Class.method, arg, kwarg)`
+    """
+
     def __init__(self, name, *args, **kwargs):
         self.name = name
         self.args = args
@@ -292,7 +351,7 @@ def __call__(self, obj, *args, **kwargs):
         from collections.abc import Callable
         from functools import partial
 
-        all_args = list(self.args) + list(args)
+        all_args = merge_args(self.args, args)
         all_kwargs = {**self.kwargs, **kwargs}
 
         func = getattr(obj, self.name, None)
@@ -301,7 +360,7 @@ def __call__(self, obj, *args, **kwargs):
             if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
                 numpy_func = getattr(np, self.name)
                 func = partial(numpy_func, obj)
-                # remove typical xr args like "dim"
+                # remove typical xarray args like "dim"
                 exclude_kwargs = ("dim", "dims")
                 all_kwargs = {
                     key: value
@@ -318,12 +377,21 @@ def __repr__(self):
 
 
 class function:
-    def __init__(self, name_or_function, *args, **kwargs):
+    """ wrapper class for numpy functions
+
+    Same as method, but the name is used for referencing numpy functions
+    """
+
+    def __init__(self, name_or_function, *args, function_label=None, **kwargs):
         if callable(name_or_function):
-            self.name = name_or_function.__name__
+            self.name = (
+                function_label
+                if function_label is not None
+                else name_or_function.__name__
+            )
             self.func = name_or_function
         else:
-            self.name = name_or_function
+            self.name = name_or_function if function_label is None else function_label
             self.func = getattr(np, name_or_function)
             if self.func is None:
                 raise AttributeError(
@@ -334,7 +402,7 @@ def __init__(self, name_or_function, *args, 
**kwargs): self.kwargs = kwargs def __call__(self, *args, **kwargs): - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} return self.func(*all_args, **all_kwargs) @@ -343,6 +411,7 @@ def __repr__(self): return f"function_{self.name}" +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataarray(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -358,6 +427,7 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, actual) +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataset(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -1245,6 +1315,769 @@ def test_dot_dataarray(dtype): assert_equal_with_units(expected, actual) +def delete_attrs(*to_delete): + def wrapper(cls): + for item in to_delete: + setattr(cls, item, None) + + return cls + + return wrapper + + +@delete_attrs( + "test_getitem_with_mask", + "test_getitem_with_mask_nd_indexer", + "test_index_0d_string", + "test_index_0d_datetime", + "test_index_0d_timedelta64", + "test_0d_time_data", + "test_datetime64_conversion", + "test_timedelta64_conversion", + "test_pandas_period_index", + "test_1d_math", + "test_1d_reduce", + "test_array_interface", + "test___array__", + "test_copy_index", + "test_concat_number_strings", + "test_concat_fixed_len_str", + "test_concat_mixed_dtypes", + "test_pandas_datetime64_with_tz", + "test_pandas_data", + "test_multiindex", +) +class TestVariable(VariableSubclassobjects): + @staticmethod + def cls(dims, data, *args, **kwargs): + return xr.Variable( + dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs + ) + + @pytest.mark.parametrize( + "func", + ( + method("all"), + method("any"), + method("argmax"), + method("argmin"), + method("argsort"), + method("cumprod"), + method("cumsum"), + method("max"), + method("mean"), + method("median"), + method("min"), + pytest.param( + method("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + method("std"), + method("sum"), + method("var"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + variable = xr.Variable("x", array) + + units = extract_units(func(array)) + expected = attach_units(func(strip_units(variable)), units) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", np.float32), + method("conj"), + method("conjugate"), + method("clip", min=2, max=7), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit if isinstance(item, (int, float, list)) else item + for item in func.args + ] + kwargs = { + key: value * unit if 
isinstance(value, (int, float, list)) else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name in ("searchsorted", "clip"): + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("item", 5), method("searchsorted", 5)), ids=repr + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + for item in func.args + ] + kwargs = { + key: value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail(reason="uses 0 as a replacement"), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="converts to fill value's unit"), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def 
test_missing_value_fillna(self, unit, error): + value = 0 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + variable.fillna(value=fill_value) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude + ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="checking for identical units does not work properly, yet" + ), + ), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) + def test_comparisons(self, func, unit, convert_data, dtype): + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and is_compatible(unit_registry.m, unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if is_compatible(unit_registry.m, unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() + ) + + actual = func(variable, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.s + variable = xr.Variable("x", array) + + expected = attach_units( + 
strip_units(variable).isel(x=indices), extract_units(variable) + ) + actual = variable.isel(x=indices) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="getitem_with_mask converts to the unit of other" + ), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None}, + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_squeeze(self, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") + array = np.ones(shape=shape) * unit_registry.m + variable = xr.Variable(names, array) + + expected = attach_units( + strip_units(variable).squeeze(), 
extract_units(variable)
+        )
+        actual = variable.squeeze()
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+        names = tuple(name for name, size in zip(names, shape) if size == 1)
+        for name in names:
+            expected = attach_units(
+                strip_units(variable).squeeze(dim=name), extract_units(variable)
+            )
+            actual = variable.squeeze(dim=name)
+
+            assert_units_equal(expected, actual)
+            xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.parametrize(
+        "func",
+        (
+            method("coarsen", windows={"y": 2}, func=np.mean),
+            method("quantile", q=[0.25, 0.75]),
+            pytest.param(
+                method("rank", dim="x"),
+                marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"),
+            ),
+            method("roll", {"x": 2}),
+            pytest.param(
+                method("rolling_window", "x", 3, "window"),
+                marks=pytest.mark.xfail(reason="converts to ndarray"),
+            ),
+            method("reduce", np.std, "x"),
+            method("round", 2),
+            pytest.param(
+                method("shift", {"x": -2}),
+                marks=pytest.mark.xfail(
+                    reason="trying to concatenate ndarray to quantity"
+                ),
+            ),
+            method("transpose", "y", "x"),
+        ),
+        ids=repr,
+    )
+    def test_computation(self, func, dtype):
+        base_unit = unit_registry.m
+        array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit
+        variable = xr.Variable(("x", "y"), array)
+
+        expected = attach_units(func(strip_units(variable)), extract_units(variable))
+
+        actual = func(variable)
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.parametrize(
+        "unit,error",
+        (
+            pytest.param(1, DimensionalityError, id="no_unit"),
+            pytest.param(
+                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+            ),
+            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+            pytest.param(unit_registry.cm, None, id="compatible_unit"),
+            pytest.param(unit_registry.m, None, id="identical_unit"),
+        ),
+    )
+    def test_searchsorted(self, unit, error, dtype):
+        base_unit = unit_registry.m
+        array = np.linspace(0, 5, 10).astype(dtype) * base_unit
+        variable = xr.Variable("x", array)
+
+        value = 0 * unit
+
+        if error is not None:
+            with pytest.raises(error):
+                variable.searchsorted(value)
+
+            return
+
+        expected = strip_units(variable).searchsorted(
+            strip_units(convert_units(value, {None: base_unit}))
+        )
+
+        actual = variable.searchsorted(value)
+
+        assert_units_equal(expected, actual)
+        np.testing.assert_allclose(expected, actual)
+
+    def test_stack(self, dtype):
+        array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
+        variable = xr.Variable(("x", "y"), array)
+
+        expected = attach_units(
+            strip_units(variable).stack(z=("x", "y")), extract_units(variable)
+        )
+        actual = variable.stack(z=("x", "y"))
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    def test_unstack(self, dtype):
+        array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m
+        variable = xr.Variable("z", array)
+
+        expected = attach_units(
+            strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable)
+        )
+        actual = variable.unstack(z={"x": 3, "y": 10})
+
+        assert_units_equal(expected, actual)
+        xr.testing.assert_identical(expected, actual)
+
+    @pytest.mark.xfail(reason="ignores units")
+    @pytest.mark.parametrize(
+        "unit,error",
+        (
+            pytest.param(1, DimensionalityError, id="no_unit"),
+            pytest.param(
+                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
+            ),
+            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
+            
pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable(("y", "z"), array2) + + if error is not None: + with pytest.raises(error): + variable.concat(other) + + return + + units = extract_units(variable) + expected = attach_units( + strip_units(variable).concat(strip_units(convert_units(other, units))), + units, + ) + actual = variable.concat(other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail( + reason="is not treated the same as dimensionless" + ), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_with_fill_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = -100 * unit + + func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, fill_value=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + 
strip_units(variable), + fill_value=strip_units(convert_units(fill_value, units)), + ), + units, + ) + actual = func(variable, fill_value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( @@ -3539,6 +4372,7 @@ def test_stacking_stacked(self, func, dtype): assert_equal_with_units(expected, actual) + @pytest.mark.xfail(reason="does not work with quantities yet") def test_to_stacked_array(self, dtype): labels = np.arange(5).astype(dtype) * unit_registry.s arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} From 46dfb779df2663c4b012124ae2ec193cca086c35 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 17 Jan 2020 11:21:30 -0500 Subject: [PATCH 18/29] remove DataArray and Dataset constructor deprecations for 0.15 (#3560) * remove 0.15 deprecations * whatsnew --- doc/whats-new.rst | 5 +++++ xarray/core/dataarray.py | 16 +--------------- xarray/core/dataset.py | 14 +------------- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17c679d27b5..a60419f89a3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,11 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ +- Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which + have been deprecated since 0.12. (:pull:`3650`). + Instead, specify the encoding when writing to disk or set + the ``encoding`` attribute directly. + By `Maximilian Roos `_ New Features ~~~~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 15db6ed468e..0e67a791834 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -267,8 +267,6 @@ def __init__( dims: Union[Hashable, Sequence[Hashable], None] = None, name: Hashable = None, attrs: Mapping = None, - # deprecated parameters - encoding=None, # internal parameters indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, @@ -313,20 +311,10 @@ def __init__( Attributes to assign to the new instance. By default, an empty attribute dictionary is initialized. """ - if encoding is not None: - warnings.warn( - "The `encoding` argument to `DataArray` is deprecated, and . " - "will be removed in 0.15. 
" - "Instead, specify the encoding when writing to disk or " - "set the `encoding` attribute directly.", - FutureWarning, - stacklevel=2, - ) if fastpath: variable = data assert dims is None assert attrs is None - assert encoding is None else: # try to fill in arguments from data if they weren't supplied if coords is None: @@ -348,13 +336,11 @@ def __init__( name = getattr(data, "name", None) if attrs is None and not isinstance(data, PANDAS_TYPES): attrs = getattr(data, "attrs", None) - if encoding is None: - encoding = getattr(data, "encoding", None) data = _check_data_shape(data, coords, dims) data = as_compatible_data(data) coords, dims = _infer_coords_and_dims(data.shape, coords, dims) - variable = Variable(dims, data, attrs, encoding, fastpath=True) + variable = Variable(dims, data, attrs, fastpath=True) indexes = dict( _extract_indexes_from_coords(coords) ) # needed for to_dataset diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 79f1030fabe..4564dc80a8a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -463,7 +463,6 @@ def __init__( data_vars: Mapping[Hashable, Any] = None, coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, - compat=None, ): """To load data from a file or file-like object, use the `open_dataset` function. @@ -513,18 +512,7 @@ def __init__( attrs : dict-like, optional Global attributes to save on this dataset. - compat : deprecated """ - if compat is not None: - warnings.warn( - "The `compat` argument to Dataset is deprecated and will be " - "removed in 0.15." - "Instead, use `merge` to control how variables are combined", - FutureWarning, - stacklevel=2, - ) - else: - compat = "broadcast_equals" # TODO(shoyer): expose indexes as a public argument in __init__ @@ -544,7 +532,7 @@ def __init__( coords = coords.variables variables, coord_names, dims, indexes = merge_data_and_coords( - data_vars, coords, compat=compat + data_vars, coords, compat="broadcast_equals" ) self._attrs = dict(attrs) if attrs is not None else None From 924c05212223defe0720c9da7cdb7583c2c5b696 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 18:51:50 +0000 Subject: [PATCH 19/29] Bump mypy to v0.761 (#3704) --- .pre-commit-config.yaml | 2 +- ci/requirements/py36-min-all-deps.yml | 2 +- ci/requirements/py36.yml | 2 +- ci/requirements/py37-windows.yml | 2 +- ci/requirements/py37.yml | 2 +- xarray/core/computation.py | 2 +- xarray/core/dataset.py | 6 +++--- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 502120cd5dc..ed62c1c256e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.730 # Must match ci/requirements/*.yml + rev: v0.761 # Must match ci/requirements/*.yml hooks: - id: mypy # run these occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index 3f10a158f91..e7756172311 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -25,7 +25,7 @@ dependencies: - iris=2.2 - lxml=4.4 # Optional dep of pydap - matplotlib=3.1 - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis=1.2 - netcdf4=1.4 - numba=0.44 diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml index 820160b19cc..7450fafbd86 100644 --- 
a/ci/requirements/py36.yml +++ b/ci/requirements/py36.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/py37-windows.yml index 614a3bb1fab..d9e634c74ae 100644 --- a/ci/requirements/py37-windows.yml +++ b/ci/requirements/py37-windows.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index 4a7aaf7d32b..2f879e29f87 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -21,7 +21,7 @@ dependencies: - iris - lxml # Optional dep of pydap - matplotlib - - mypy=0.730 # Must match .pre-commit-config.yaml + - mypy=0.761 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 643c1137d6c..34de5edefc5 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -304,7 +304,7 @@ def _as_variables_or_variable(arg): def _unpack_dict_tuples( result_vars: Mapping[Hashable, Tuple[Variable, ...]], num_outputs: int ) -> Tuple[Dict[Hashable, Variable], ...]: - out = tuple({} for _ in range(num_outputs)) # type: ignore + out: Tuple[Dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs)) for name, values in result_vars.items(): for value, results_dict in zip(values, out): results_dict[name] = value diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4564dc80a8a..92a64508104 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -897,11 +897,11 @@ def _replace( if dims is not None: self._dims = dims if attrs is not _default: - self._attrs = attrs # type: ignore # FIXME need mypy 0.750 + self._attrs = attrs if indexes is not _default: - self._indexes = indexes # type: ignore # FIXME need mypy 0.750 + self._indexes = indexes if encoding is not _default: - self._encoding = encoding # type: ignore # FIXME need mypy 0.750 + self._encoding = encoding obj = self else: if variables is None: From e646dd987f414e9c1b12b106e17ba650762fa6b6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 20:42:24 +0000 Subject: [PATCH 20/29] hardcoded xarray.__all__ (#3703) * Fix mypy --strict * isort * mypy version bump * trivial * testing.__all__ and ufuncs.__all__ * isort * Revert mypy version bump * Revert isort --- doc/whats-new.rst | 2 ++ xarray/__init__.py | 53 ++++++++++++++++++++++++++++++++++- xarray/testing.py | 7 +++++ xarray/ufuncs.py | 70 ++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 123 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a60419f89a3..fef1b988f01 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -87,6 +87,8 @@ Bug fixes - Fix a regression in :py:meth:`Dataset.drop`: allow passing any iterable when dropping variables (:issue:`3552`, :pull:`3693`) By `Justus Magin `_. +- Fixed errors emitted by ``mypy --strict`` in modules that import xarray. + (:issue:`3695`) by `Guido Imperiale `_. 
Documentation ~~~~~~~~~~~~~ diff --git a/xarray/__init__.py b/xarray/__init__.py index 394dd0f80bc..0b44c293b3f 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,5 +1,4 @@ """ isort:skip_file """ -# flake8: noqa from ._version import get_versions @@ -42,3 +41,55 @@ from . import testing from .core.common import ALL_DIMS + +# A hardcoded __all__ variable is necessary to appease +# `mypy --strict` running in projects that import xarray. +__all__ = ( + # Sub-packages + "ufuncs", + "testing", + "tutorial", + # Top-level functions + "align", + "apply_ufunc", + "as_variable", + "auto_combine", + "broadcast", + "cftime_range", + "combine_by_coords", + "combine_nested", + "concat", + "decode_cf", + "dot", + "full_like", + "load_dataarray", + "load_dataset", + "map_blocks", + "merge", + "ones_like", + "open_dataarray", + "open_dataset", + "open_mfdataset", + "open_rasterio", + "open_zarr", + "register_dataarray_accessor", + "register_dataset_accessor", + "save_mfdataset", + "set_options", + "show_versions", + "where", + "zeros_like", + # Classes + "CFTimeIndex", + "Coordinate", + "DataArray", + "Dataset", + "IndexVariable", + "Variable", + # Exceptions + "MergeError", + "SerializationWarning", + # Constants + "__version__", + "ALL_DIMS", +) diff --git a/xarray/testing.py b/xarray/testing.py index 5c3ca8a3cca..ac189f7e023 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -10,6 +10,13 @@ from xarray.core.indexes import default_indexes from xarray.core.variable import IndexVariable, Variable +__all__ = ( + "assert_allclose", + "assert_chunks_equal", + "assert_equal", + "assert_identical", +) + def _decode_string_data(data): if data.dtype.kind == "S": diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index ae2c5c574b6..8ab2b7cfe31 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -132,14 +132,68 @@ def _create_op(name): return func -__all__ = """logaddexp logaddexp2 conj exp log log2 log10 log1p expm1 sqrt - square sin cos tan arcsin arccos arctan arctan2 hypot sinh cosh - tanh arcsinh arccosh arctanh deg2rad rad2deg logical_and - logical_or logical_xor logical_not maximum minimum fmax fmin - isreal iscomplex isfinite isinf isnan signbit copysign nextafter - ldexp fmod floor ceil trunc degrees radians rint fix angle real - imag fabs sign frexp fmod - """.split() +__all__ = ( # noqa: F822 + "angle", + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctan2", + "arctanh", + "ceil", + "conj", + "copysign", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "expm1", + "fabs", + "fix", + "floor", + "fmax", + "fmin", + "fmod", + "fmod", + "frexp", + "hypot", + "imag", + "iscomplex", + "isfinite", + "isinf", + "isnan", + "isreal", + "ldexp", + "log", + "log10", + "log1p", + "log2", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "nextafter", + "rad2deg", + "radians", + "real", + "rint", + "sign", + "signbit", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +) + for name in __all__: globals()[name] = _create_op(name) From c32e58b4fff72816c6b554db51509bea6a891cdc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 17 Jan 2020 21:00:23 +0000 Subject: [PATCH 21/29] One-off isort run (#3705) * isort run * Fix docs regression --- doc/conf.py | 2 +- properties/test_pandas_roundtrip.py | 16 ++++++++-------- xarray/core/dataset.py | 2 +- xarray/core/duck_array_ops.py | 2 +- xarray/core/parallel.py | 2 +- xarray/tests/test_accessor_dt.py | 3 +-- xarray/tests/test_units.py | 
1 + 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 11abda6bb63..578f9cf550d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -25,7 +25,7 @@ os.environ["PYTHONPATH"] = str(root) sys.path.insert(0, str(root)) -import xarray +import xarray # isort:skip allowed_failures = set() diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index a8005d319d6..5fc097f1f5e 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -1,20 +1,20 @@ """ Property-based tests for roundtripping between xarray and pandas objects. """ -import pytest - -pytest.importorskip("hypothesis") - from functools import partial -import hypothesis.extra.numpy as npst -import hypothesis.extra.pandas as pdst -import hypothesis.strategies as st -from hypothesis import given import numpy as np import pandas as pd +import pytest + import xarray as xr +pytest.importorskip("hypothesis") +import hypothesis.extra.numpy as npst # isort:skip +import hypothesis.extra.pandas as pdst # isort:skip +import hypothesis.strategies as st # isort:skip +from hypothesis import given # isort:skip + numeric_dtypes = st.one_of( npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 92a64508104..82ddc8a535f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -92,8 +92,8 @@ IndexVariable, Variable, as_variable, - broadcast_variables, assert_unique_multiindex_level_names, + broadcast_variables, ) if TYPE_CHECKING: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 98b371ab7c3..3fa97f0b448 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,7 +11,7 @@ import numpy as np import pandas as pd -from . import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils +from . 
import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast from .pycompat import dask_array_type diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index e4fb5803191..facfa06b23c 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -13,8 +13,8 @@ from typing import ( Any, Callable, - Dict, DefaultDict, + Dict, Hashable, Mapping, Sequence, diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index 67ca12532c7..f178720a6e1 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -11,8 +11,7 @@ requires_cftime, requires_dask, ) - -from .test_dask import raise_if_dask_computes, assert_chunks_equal +from .test_dask import assert_chunks_equal, raise_if_dask_computes class TestDatetimeAccessor: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2cb1550c088..6f9128839a7 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,6 +7,7 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE + from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") From 5c97641dd129a9dcc4346a137bcf32a5a982d37a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Mon, 20 Jan 2020 06:46:13 +0100 Subject: [PATCH 22/29] =?UTF-8?q?ENH:=20enable=20`H5NetCDFStore`=20to=20wo?= =?UTF-8?q?rk=20with=20already=20open=20h5netcdf.File=20a=E2=80=A6=20(#361?= =?UTF-8?q?8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ENH: enable `H5NetCDFStore` to work with already open h5netcdf.File and h5netcdf.Group objects, add test * FIX: add `.open` method for file-like objects * FIX: reformat using black * WIP: add item to whats-new.rst * FIX: temporary fix to tackle issue #3680 * FIX: do not use private API, use find_root_and_group instead * FIX: reformat using black --- doc/whats-new.rst | 5 +++ xarray/backends/api.py | 6 ++-- xarray/backends/common.py | 2 +- xarray/backends/h5netcdf_.py | 61 ++++++++++++++++++++++++++--------- xarray/tests/test_backends.py | 23 ++++++++++++- 5 files changed, 77 insertions(+), 20 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fef1b988f01..f9e2e9270b0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,11 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Support using an existing, opened h5netcdf ``File`` with + :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an + :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened + using other means (:issue:`3618`). + By `Kai MĂ¼hlbauer `_. - Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking to a single chunk along all reduction axes. (:issue:`2999`). By `Deepak Cherian `_. 
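
For reference, a minimal usage sketch of the new entry point (modeled on the test added below in ``test_backends.py``; the file and group names here are hypothetical):

```python
import h5netcdf
import xarray as xr
from xarray.backends import H5NetCDFStore

# open the file by other means first, e.g. to inspect it or pick a subgroup
h5 = h5netcdf.File("example.nc", mode="r")

# wrap either a group object directly, or the root file plus a group name
store = H5NetCDFStore(h5["g"])  # equivalently: H5NetCDFStore(h5, group="g")
with xr.open_dataset(store) as ds:
    print(ds)
```
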
diff --git a/xarray/backends/api.py b/xarray/backends/api.py index eea1fb9ddce..0990fc46219 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -503,7 +503,7 @@ def maybe_decode_store(store, lock=False): elif engine == "pydap": store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) elif engine == "pynio": @@ -527,7 +527,7 @@ def maybe_decode_store(store, lock=False): if engine == "scipy": store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == "h5netcdf": - store = backends.H5NetCDFStore( + store = backends.H5NetCDFStore.open( filename_or_obj, group=group, lock=lock, **backend_kwargs ) @@ -981,7 +981,7 @@ def open_mfdataset( WRITEABLE_STORES: Dict[str, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, - "h5netcdf": backends.H5NetCDFStore, + "h5netcdf": backends.H5NetCDFStore.open, } diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 9786e0a0203..fa3ee19f542 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -34,7 +34,7 @@ def find_root_and_group(ds): """Find the root and group name of a netCDF4/h5netcdf dataset.""" hierarchy = () while ds.parent is not None: - hierarchy = (ds.name,) + hierarchy + hierarchy = (ds.name.split("/")[-1],) + hierarchy ds = ds.parent group = "/" + "/".join(hierarchy) return ds, group diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 51ed512f98b..203213afe82 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -4,9 +4,9 @@ from .. import Variable from ..core import indexing -from ..core.utils import FrozenDict -from .common import WritableCFDataStore -from .file_manager import CachingFileManager +from ..core.utils import FrozenDict, is_remote_uri +from .common import WritableCFDataStore, find_root_and_group +from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( BaseNetCDF4Array, @@ -69,8 +69,47 @@ class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ - def __init__( - self, + __slots__ = ( + "autoclose", + "format", + "is_remote", + "lock", + "_filename", + "_group", + "_manager", + "_mode", + ) + + def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False): + + import h5netcdf + + if isinstance(manager, (h5netcdf.File, h5netcdf.Group)): + if group is None: + root, group = find_root_and_group(manager) + else: + if not type(manager) is h5netcdf.File: + raise ValueError( + "must supply a h5netcdf.File if the group " + "argument is provided" + ) + root = manager + manager = DummyFileManager(root) + + self._manager = manager + self._group = group + self._mode = mode + self.format = None + # todo: utilizing find_root_and_group seems a bit clunky + # making filename available on h5netcdf.Group seems better + self._filename = find_root_and_group(self.ds)[0].filename + self.is_remote = is_remote_uri(self._filename) + self.lock = ensure_lock(lock) + self.autoclose = autoclose + + @classmethod + def open( + cls, filename, mode="r", format=None, @@ -86,22 +125,14 @@ def __init__( kwargs = {"invalid_netcdf": invalid_netcdf} - self._manager = CachingFileManager( - h5netcdf.File, filename, mode=mode, kwargs=kwargs - ) - if lock is None: if mode == "r": 
lock = HDF5_LOCK else: lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) - self._group = group - self.format = format - self._filename = filename - self._mode = mode - self.lock = ensure_lock(lock) - self.autoclose = autoclose + manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs) + return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose) def _acquire(self, needs_lock=True): with self._manager.acquire_context(needs_lock) as root: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 4fccdf2dd6c..0436ae9d244 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2182,7 +2182,7 @@ class TestH5NetCDFData(NetCDF4Base): @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: - yield backends.H5NetCDFStore(tmp_file, "w") + yield backends.H5NetCDFStore.open(tmp_file, "w") @pytest.mark.filterwarnings("ignore:complex dtypes are supported by h5py") @pytest.mark.parametrize( @@ -2345,6 +2345,27 @@ def test_dump_encodings_h5py(self): assert actual.x.encoding["compression"] == "lzf" assert actual.x.encoding["compression_opts"] is None + def test_already_open_dataset_group(self): + import h5netcdf + + with create_tmp_file() as tmp_file: + with nc4.Dataset(tmp_file, mode="w") as nc: + group = nc.createGroup("g") + v = group.createVariable("x", "int") + v[...] = 42 + + h5 = h5netcdf.File(tmp_file, mode="r") + store = backends.H5NetCDFStore(h5["g"]) + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + + h5 = h5netcdf.File(tmp_file, mode="r") + store = backends.H5NetCDFStore(h5, group="g") + with open_dataset(store) as ds: + expected = Dataset({"x": ((), 42)}) + assert_identical(expected, ds) + @requires_h5netcdf class TestH5NetCDFFileObject(TestH5NetCDFData): From aa0f96383062c48cb17f46ef951075c0494c9c0a Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Mon, 20 Jan 2020 13:09:26 +0100 Subject: [PATCH 23/29] Feature/align in dot (#3699) * add tests * implement align * whats new * fix changes to whats new * review: fix typos --- doc/whats-new.rst | 5 ++- xarray/core/computation.py | 7 ++++ xarray/tests/test_computation.py | 54 +++++++++++++++++++++++++++++++ xarray/tests/test_dataarray.py | 55 ++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f9e2e9270b0..25133df8a29 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,12 +21,15 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ - - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which have been deprecated since 0.12. (:pull:`3650`). Instead, specify the encoding when writing to disk or set the ``encoding`` attribute directly. By `Maximilian Roos `_ +- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now + use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``; + :issue:`3694`) by `Mathias Hauser `_. + New Features ~~~~~~~~~~~~ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 34de5edefc5..eb9ca8c17fc 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from . 
import duck_array_ops, utils from .alignment import deep_align from .merge import merge_coordinates_without_align +from .options import OPTIONS from .pycompat import dask_array_type from .utils import is_dict_like from .variable import Variable @@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs): subscripts = ",".join(subscripts_list) subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]]) + join = OPTIONS["arithmetic_join"] + # using "inner" emulates `(a * b).sum()` for all joins (except "exact") + if join != "exact": + join = "inner" + # subscripts should be passed to np.einsum as arg, not as kwargs. We need # to construct a partial function for apply_ufunc to work. func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs) @@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs): *arrays, input_core_dims=input_core_dims, output_core_dims=output_core_dims, + join=join, dask="allowed", ) return result.transpose(*[d for d in all_dims if d in result.dims]) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 1f2634cc9b0..2d373d12095 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1043,6 +1043,60 @@ def test_dot(use_dask): pickle.loads(pickle.dumps(xr.dot(da_a))) +@pytest.mark.parametrize("use_dask", [True, False]) +def test_dot_align_coords(use_dask): + # GH 3694 + + if use_dask: + if not has_dask: + pytest.skip("test for dask.") + + a = np.arange(30 * 4).reshape(30, 4) + b = np.arange(30 * 4 * 5).reshape(30, 4, 5) + + # use partially overlapping coords + coords_a = {"a": np.arange(30), "b": np.arange(4)} + coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)} + + da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a) + da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b) + + if use_dask: + da_a = da_a.chunk({"a": 3}) + da_b = da_b.chunk({"a": 3}) + + # join="inner" is the default + actual = xr.dot(da_a, da_b) + # `dot` sums over the common dimensions of the arguments + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + actual = xr.dot(da_a, da_b, dims=...) 
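+    # `dims=...` contracts over all dimensions of both operands, so the
+    # expected value reduces to the plain sum of the elementwise product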
+ expected = (da_a * da_b).sum() + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + xr.dot(da_a, da_b) + + # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all + # join method (except "exact") + with xr.set_options(arithmetic_join="left"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="right"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + with xr.set_options(arithmetic_join="outer"): + actual = xr.dot(da_a, da_b) + expected = (da_a * da_b).sum(["a", "b"]) + xr.testing.assert_allclose(expected, actual) + + def test_where(): cond = xr.DataArray([True, False], dims="x") actual = xr.where(cond, 1, 0) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 786eb5007a6..962be7548bc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3973,6 +3973,43 @@ def test_dot(self): with pytest.raises(TypeError): da.dot(dm.values) + def test_dot_align_coords(self): + # GH 3694 + + x = np.linspace(-3, 3, 6) + y = np.linspace(-3, 3, 5) + z_a = range(4) + da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4)) + da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"]) + + z_m = range(2, 6) + dm_vals = range(4) + dm = DataArray(dm_vals, coords=[z_m], dims=["z"]) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da.dot(dm) + + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + + # nd dot 1d + actual = da.dot(dm) + expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0]) + expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"]) + assert_equal(expected, actual) + + # multiple shared dims + dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4)) + j = np.linspace(-3, 3, 20) + dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"]) + da_aligned, dm_aligned = xr.align(da, dm, join="inner") + actual = da.dot(dm) + expected_vals = np.tensordot( + da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2]) + ) + expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"]) + assert_equal(expected, actual) + def test_matmul(self): # copied from above (could make a fixture) @@ -3986,6 +4023,24 @@ def test_matmul(self): expected = da.dot(da) assert_identical(result, expected) + def test_matmul_align_coords(self): + # GH 3694 + + x_a = np.arange(6) + x_b = np.arange(2, 8) + da_vals = np.arange(6) + da_a = DataArray(da_vals, coords=[x_a], dims=["x"]) + da_b = DataArray(da_vals, coords=[x_b], dims=["x"]) + + # only test arithmetic_join="inner" (=default) + result = da_a @ da_b + expected = da_a.dot(da_b) + assert_identical(result, expected) + + with xr.set_options(arithmetic_join="exact"): + with raises_regex(ValueError, "indexes along dimension"): + da_a @ da_b + def test_binary_op_propagate_indexes(self): # regression test for GH2227 self.dv["x"] = np.arange(self.dv.sizes["x"]) From 27a392906f12af5fa60d21bb327d739ee8411c82 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 22 Jan 2020 15:40:33 +0000 Subject: [PATCH 24/29] setuptools-scm and one-liner setup.py (#3714) * Replace versioneer with setuptools-scm. Replace setup.py with setup.cfg. Remove pytest-runner. 
* Python 3.8 eggs

* Drop minimum pandas version

* twine check
---
 .coveragerc             |    1 -
 .gitattributes          |    1 -
 .gitignore              |    2 +-
 MANIFEST.in             |    2 -
 ci/requirements/doc.yml |    2 +-
 doc/whats-new.rst       |    3 +
 setup.cfg               |  113 ++-
 setup.py                |  110 +--
 versioneer.py           | 1883 ---------------------------------------
 xarray/__init__.py      |    9 +-
 xarray/_version.py      |  555 ------------
 11 files changed, 108 insertions(+), 2573 deletions(-)
 delete mode 100644 versioneer.py
 delete mode 100644 xarray/_version.py

diff --git a/.coveragerc b/.coveragerc
index 2bf6e08f5af..1bf19c310aa 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -5,4 +5,3 @@ omit =
     xarray/core/npcompat.py
     xarray/core/pdcompat.py
     xarray/core/pycompat.py
-    xarray/_version.py
diff --git a/.gitattributes b/.gitattributes
index daa5b82874e..a52f4ca283a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,3 +1,2 @@
 # reduce the number of merge conflicts
 doc/whats-new.rst merge=union
-xarray/_version.py export-subst
diff --git a/.gitignore b/.gitignore
index ad26864221e..5f02700de37 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@ doc/savefig
 # Packages
 *.egg
 *.egg-info
+.eggs
 dist
 build
 eggs
@@ -65,7 +66,6 @@ dask-worker-space/
 # xarray specific
 doc/_build
 doc/generated
-xarray/version.py
 xarray/tests/data/*.grib.*.idx
 
 # Sync tools
diff --git a/MANIFEST.in b/MANIFEST.in
index 4d5c34f622c..cbfb8c8cdca 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -4,6 +4,4 @@ recursive-include doc *
 prune doc/_build
 prune doc/generated
 global-exclude .DS_Store
-include versioneer.py
-include xarray/_version.py
 recursive-include xarray/static *
diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml
index 1b479ad4ca2..a8b72dc0956 100644
--- a/ci/requirements/doc.yml
+++ b/ci/requirements/doc.yml
@@ -17,7 +17,7 @@ dependencies:
   - numba
   - numpy
   - numpydoc
-  - pandas<0.25  # Hack around https://github.com/pydata/xarray/issues/3369
+  - pandas
   - rasterio
   - seaborn
   - sphinx
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 25133df8a29..67cb8e7125a 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -133,6 +133,9 @@ Internal Changes
 - Removed internal method ``Dataset._from_vars_and_coord_names``,
   which was dominated by ``Dataset._construct_direct``. (:pull:`3565`)
   By `Maximilian Roos `_
+- Replaced versioneer with setuptools-scm. Moved contents of setup.py to setup.cfg.
+  Removed pytest-runner from setup.py, as per deprecation notice on the pytest-runner
+  project. (:pull:`3714`) by `Guido Imperiale `_
 
 
 v0.14.1 (19 Nov 2019)
diff --git a/setup.cfg b/setup.cfg
index 21158e3b0ee..0e113c15306 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,95 @@
+[metadata]
+name = xarray
+author = xarray Developers
+author_email = xarray@googlegroups.com
+license = Apache
+description = N-D labeled arrays and datasets in Python
+long_description_content_type=text/x-rst
+long_description =
+    **xarray** (formerly **xray**) is an open source project and Python package
+    that makes working with labelled multi-dimensional arrays simple,
+    efficient, and fun!
+
+    xarray introduces labels in the form of dimensions, coordinates and
+    attributes on top of raw NumPy_-like arrays, which allows for a more
+    intuitive, more concise, and less error-prone developer experience.
+    The package includes a large and growing library of domain-agnostic functions
+    for advanced analytics and visualization with these data structures.
+
+    xarray was inspired by and borrows heavily from pandas_, the popular data
+    analysis package focused on labelled tabular data.
+    It is particularly tailored to working with netCDF_ files, which were the
+    source of xarray's data model, and integrates tightly with dask_ for parallel
+    computing.
+
+    .. _NumPy: https://www.numpy.org
+    .. _pandas: https://pandas.pydata.org
+    .. _dask: https://dask.org
+    .. _netCDF: https://www.unidata.ucar.edu/software/netcdf
+
+    Why xarray?
+    -----------
+    Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called
+    "tensors") are an essential part of computational science.
+    They are encountered in a wide range of fields, including physics, astronomy,
+    geoscience, bioinformatics, engineering, finance, and deep learning.
+    In Python, NumPy_ provides the fundamental data structure and API for
+    working with raw ND arrays.
+    However, real-world datasets are usually more than just raw numbers;
+    they have labels which encode information about how the array values map
+    to locations in space, time, etc.
+
+    xarray doesn't just keep track of labels on arrays -- it uses them to provide a
+    powerful and concise interface. For example:
+
+    - Apply operations over dimensions by name: ``x.sum('time')``.
+    - Select values by label instead of integer location:
+      ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``.
+    - Mathematical operations (e.g., ``x - y``) vectorize across multiple
+      dimensions (array broadcasting) based on dimension names, not shape.
+    - Flexible split-apply-combine operations with groupby:
+      ``x.groupby('time.dayofyear').mean()``.
+    - Database like alignment based on coordinate labels that smoothly
+      handles missing values: ``x, y = xr.align(x, y, join='outer')``.
+    - Keep track of arbitrary metadata in the form of a Python dictionary:
+      ``x.attrs``.
+
+    Learn more
+    ----------
+    - Documentation: `<http://xarray.pydata.org>`_
+    - Issue tracker: `<http://github.com/pydata/xarray/issues>`_
+    - Source code: `<http://github.com/pydata/xarray>`_
+    - SciPy2015 talk: `<https://www.youtube.com/watch?v=X0pAhJgySxk>`_
+
+url = https://github.com/pydata/xarray
+classifiers =
+    Development Status :: 5 - Production/Stable
+    License :: OSI Approved :: Apache Software License
+    Operating System :: OS Independent
+    Intended Audience :: Science/Research
+    Programming Language :: Python
+    Programming Language :: Python :: 3
+    Programming Language :: Python :: 3.6
+    Programming Language :: Python :: 3.7
+    Topic :: Scientific/Engineering
+
+[options]
+packages = xarray
+zip_safe = True
+include_package_data = True
+python_requires = >=3.6
+install_requires =
+    numpy >= 1.14
+    pandas >= 0.24
+setup_requires = setuptools_scm
+
+[options.package_data]
+xarray =
+    py.typed
+    tests/data/*
+    static/css/*
+    static/html/*
+
 [tool:pytest]
 python_files=test_*.py
 testpaths=xarray/tests properties
@@ -23,6 +115,7 @@ ignore=
     # line break before binary operator
     W503
 exclude=
+    .eggs
     doc
 
 [isort]
@@ -87,35 +180,19 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-seaborn.*]
 ignore_missing_imports = True
+[mypy-setuptools]
+ignore_missing_imports = True
 [mypy-sparse.*]
 ignore_missing_imports = True
 [mypy-toolz.*]
 ignore_missing_imports = True
 [mypy-zarr.*]
 ignore_missing_imports = True
-
-# setuptools is not typed
-[mypy-setup]
-ignore_errors = True
-# versioneer code
-[mypy-versioneer.*]
-ignore_errors = True
-# written by versioneer
-[mypy-xarray._version]
-ignore_errors = True
 # version spanning code is hard to type annotate (and most of this module will
 # be going away soon anyways)
 [mypy-xarray.core.pycompat]
 ignore_errors = True
 
-[versioneer]
-VCS = git
-style = pep440
-versionfile_source = xarray/_version.py
-versionfile_build = xarray/_version.py
-tag_prefix = v
-parentdir_prefix = xarray-
-
 [aliases]
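 # alias `python setup.py test` to the pytest command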
test = pytest diff --git a/setup.py b/setup.py index cba0c74aa3a..76755a445f7 100755 --- a/setup.py +++ b/setup.py @@ -1,110 +1,4 @@ #!/usr/bin/env python -import sys +from setuptools import setup -import versioneer -from setuptools import find_packages, setup - -DISTNAME = "xarray" -LICENSE = "Apache" -AUTHOR = "xarray Developers" -AUTHOR_EMAIL = "xarray@googlegroups.com" -URL = "https://github.com/pydata/xarray" -CLASSIFIERS = [ - "Development Status :: 5 - Production/Stable", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Intended Audience :: Science/Research", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Topic :: Scientific/Engineering", -] - -PYTHON_REQUIRES = ">=3.6" -INSTALL_REQUIRES = ["numpy >= 1.14", "pandas >= 0.24"] -needs_pytest = {"pytest", "test", "ptr"}.intersection(sys.argv) -SETUP_REQUIRES = ["pytest-runner >= 4.2"] if needs_pytest else [] -TESTS_REQUIRE = ["pytest >= 2.7.1"] - -DESCRIPTION = "N-D labeled arrays and datasets in Python" -LONG_DESCRIPTION = """ -**xarray** (formerly **xray**) is an open source project and Python package -that makes working with labelled multi-dimensional arrays simple, -efficient, and fun! - -Xarray introduces labels in the form of dimensions, coordinates and -attributes on top of raw NumPy_-like arrays, which allows for a more -intuitive, more concise, and less error-prone developer experience. -The package includes a large and growing library of domain-agnostic functions -for advanced analytics and visualization with these data structures. - -Xarray was inspired by and borrows heavily from pandas_, the popular data -analysis package focused on labelled tabular data. -It is particularly tailored to working with netCDF_ files, which were the -source of xarray's data model, and integrates tightly with dask_ for parallel -computing. - -.. _NumPy: https://www.numpy.org -.. _pandas: https://pandas.pydata.org -.. _dask: https://dask.org -.. _netCDF: https://www.unidata.ucar.edu/software/netcdf - -Why xarray? ------------ - -Multi-dimensional (a.k.a. N-dimensional, ND) arrays (sometimes called -"tensors") are an essential part of computational science. -They are encountered in a wide range of fields, including physics, astronomy, -geoscience, bioinformatics, engineering, finance, and deep learning. -In Python, NumPy_ provides the fundamental data structure and API for -working with raw ND arrays. -However, real-world datasets are usually more than just raw numbers; -they have labels which encode information about how the array values map -to locations in space, time, etc. - -Xarray doesn't just keep track of labels on arrays -- it uses them to provide a -powerful and concise interface. For example: - -- Apply operations over dimensions by name: ``x.sum('time')``. -- Select values by label instead of integer location: - ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. -- Mathematical operations (e.g., ``x - y``) vectorize across multiple - dimensions (array broadcasting) based on dimension names, not shape. -- Flexible split-apply-combine operations with groupby: - ``x.groupby('time.dayofyear').mean()``. -- Database like alignment based on coordinate labels that smoothly - handles missing values: ``x, y = xr.align(x, y, join='outer')``. -- Keep track of arbitrary metadata in the form of a Python dictionary: - ``x.attrs``. 
- -Learn more ----------- - -- Documentation: http://xarray.pydata.org -- Issue tracker: http://github.com/pydata/xarray/issues -- Source code: http://github.com/pydata/xarray -- SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk -""" - - -setup( - name=DISTNAME, - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - license=LICENSE, - author=AUTHOR, - author_email=AUTHOR_EMAIL, - classifiers=CLASSIFIERS, - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - python_requires=PYTHON_REQUIRES, - install_requires=INSTALL_REQUIRES, - setup_requires=SETUP_REQUIRES, - tests_require=TESTS_REQUIRE, - url=URL, - packages=find_packages(), - package_data={ - "xarray": ["py.typed", "tests/data/*", "static/css/*", "static/html/*"] - }, -) +setup(use_scm_version=True) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 7b55d5d06cb..00000000000 --- a/versioneer.py +++ /dev/null @@ -1,1883 +0,0 @@ -# flake8: noqa - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). 
Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. 
- -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. 
- -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . 
- -""" - - -import errno -import json -import os -import re -import subprocess -import sys - -try: - import configparser -except ImportError: - import ConfigParser as configparser - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . 
- setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY[ - "git" -] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except OSError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) - if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. 
- - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? 
- # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if "py2exe" in sys.modules: # py2exe enabled? 
- try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file( - target_versionfile, self._versioneer_generated_versions - ) - - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. 
- -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except OSError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except OSError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print( - " appending versionfile_source ('%s') to MANIFEST.in" - % cfg.versionfile_source - ) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. 
- do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) diff --git a/xarray/__init__.py b/xarray/__init__.py index 0b44c293b3f..0388eaf3dab 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,9 +1,12 @@ """ isort:skip_file """ -from ._version import get_versions +import pkg_resources -__version__ = get_versions()["version"] -del get_versions +try: + __version__ = pkg_resources.get_distribution("xarray").version +except Exception: + # Local copy, not installed with setuptools + __version__ = "unknown" from .core.alignment import align, broadcast from .core.common import full_like, zeros_like, ones_like diff --git a/xarray/_version.py b/xarray/_version.py deleted file mode 100644 index 0ccb33a5e56..00000000000 --- a/xarray/_version.py +++ /dev/null @@ -1,555 +0,0 @@ -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "xarray-" - cfg.versionfile_source = "xarray/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print(f"unable to find command, tried {commands}") - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format( - full_tag, tag_prefix - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split("/"): - root = os.path.dirname(root) - except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } From 17b70caa6eafa062fd31e7f39334b3de922ff422 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 22 Jan 2020 15:43:18 +0000 Subject: [PATCH 25/29] apply_ufunc: Add meta kwarg + bump dask to 2.2 (#3660) * apply_func: Set meta=np.ndarray when vectorize=True and dask="parallelized" Closes #3574 * Add meta kwarg to apply_ufunc. * Bump minimum dask to 2.1.0 * Update distributed too * bump minimum dask, distributed to 2.2 * Update whats-new * minor. * fix whats-new * Attempt numpy=1.15 * Revert "Attempt numpy=1.15" This reverts commit 2b22470b799c2dba14ac834a623d9d011104d3ce. * xfail test. * More xfailed tests. * Update xfail reason. * fix whats-new * Add test to ensure meta is passed on to dask. * Use skipif instead of xfail. 
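
A minimal usage sketch of the new keyword, adapted from the regression
test added below (test_vectorize_dask_new_output_dims); the array and
function here are illustrative, not part of the patch:

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("x", "y")).chunk({"x": 1})

    # With vectorize=True, apply_ufunc now defaults meta to np.ndarray so
    # that dask does not probe the np.vectorize-wrapped function with
    # size-0 inputs; meta may also be passed explicitly (e.g. a sparse
    # array prototype) and is forwarded to dask.array.blockwise.
    result = xr.apply_ufunc(
        lambda x: x[np.newaxis, ...],  # adds a new dimension, as in GH3574
        da,
        output_core_dims=[["z"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        output_sizes={"z": 1},
        meta=np.ndarray,
    )
    result.compute()  # result has dims ("x", "y", "z")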
--- ci/requirements/py36-min-all-deps.yml | 4 ++-- doc/whats-new.rst | 7 ++++++- xarray/core/computation.py | 23 ++++++++++++++++++++++- xarray/tests/test_backends.py | 12 ++++++++++-- xarray/tests/test_computation.py | 18 ++++++++++++++++++ xarray/tests/test_sparse.py | 13 +++++++++++++ 6 files changed, 71 insertions(+), 6 deletions(-) diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index e7756172311..ea707461013 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -15,8 +15,8 @@ dependencies: - cfgrib=0.9 - cftime=1.0 - coveralls - - dask=1.2 - - distributed=1.27 + - dask=2.2 + - distributed=2.2 - flake8 - h5netcdf=0.7 - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 67cb8e7125a..22088357e0b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,7 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ +- Bumped minimum ``dask`` version to 2.2. - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which have been deprecated since 0.12. (:pull:`3650`). Instead, specify the encoding when writing to disk or set @@ -50,6 +51,8 @@ New Features - Added the ``count`` reduction method to both :py:class:`~core.rolling.DatasetCoarsen` and :py:class:`~core.rolling.DataArrayCoarsen` objects. (:pull:`3500`) By `Deepak Cherian `_ +- Add ``meta`` kwarg to :py:func:`~xarray.apply_ufunc`; this is passed on to + :py:meth:`dask.array.blockwise`. (:pull:`3660`) By `Deepak Cherian `_. - Add `attrs_file` option in :py:func:`~xarray.open_mfdataset` to choose the source file for global attributes in a multi-file dataset (:issue:`2382`, :pull:`3498`) by `Julien Seguinot _`. @@ -63,7 +66,9 @@ New Features Bug fixes ~~~~~~~~~ - +- Applying a user-defined function that adds new dimensions using :py:func:`apply_ufunc` + and ``vectorize=True`` now works with ``dask > 2.0``. (:issue:`3574`, :pull:`3660`). + By `Deepak Cherian `_. - Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger `_. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index eb9ca8c17fc..d2c5c32bc00 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -548,6 +548,7 @@ def apply_variable_ufunc( output_dtypes=None, output_sizes=None, keep_attrs=False, + meta=None, ): """Apply a ndarray level function over Variable and/or ndarray objects. """ @@ -590,6 +591,7 @@ def func(*arrays): signature, output_dtypes, output_sizes, + meta, ) elif dask == "allowed": @@ -648,7 +650,14 @@ def func(*arrays): def _apply_blockwise( - func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None + func, + args, + input_dims, + output_dims, + signature, + output_dtypes, + output_sizes=None, + meta=None, ): import dask.array @@ -720,6 +729,7 @@ def _apply_blockwise( dtype=dtype, concatenate=True, new_axes=output_sizes, + meta=meta, ) @@ -761,6 +771,7 @@ def apply_ufunc( dask: str = "forbidden", output_dtypes: Sequence = None, output_sizes: Mapping[Any, int] = None, + meta: Any = None, ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. @@ -857,6 +868,9 @@ def apply_ufunc( Optional mapping from dimension names to sizes for outputs. Only used if dask='parallelized' and new dimensions (not found on inputs) appear on outputs. + meta : optional + Size-0 object representing the type of array wrapped by dask array. 
Passed on to + ``dask.array.blockwise``. Returns ------- @@ -990,6 +1004,11 @@ def earth_mover_distance(first_samples, func = functools.partial(func, **kwargs) if vectorize: + if meta is None: + # set meta=np.ndarray by default for numpy vectorized functions + # work around dask bug computing meta with vectorized functions: GH5642 + meta = np.ndarray + if signature.all_core_dims: func = np.vectorize( func, otypes=output_dtypes, signature=signature.to_gufunc_string() @@ -1006,6 +1025,7 @@ def earth_mover_distance(first_samples, dask=dask, output_dtypes=output_dtypes, output_sizes=output_sizes, + meta=meta, ) if any(isinstance(a, GroupBy) for a in args): @@ -1020,6 +1040,7 @@ def earth_mover_distance(first_samples, dataset_fill_value=dataset_fill_value, keep_attrs=keep_attrs, dask=dask, + meta=meta, ) return apply_groupby_func(this_apply, *args) elif any(is_dict_like(a) for a in args): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0436ae9d244..bb77cbb94fe 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -37,7 +37,7 @@ from xarray.core import indexing from xarray.core.options import set_options from xarray.core.pycompat import dask_array_type -from xarray.tests import mock +from xarray.tests import LooseVersion, mock from . import ( arm_xfail, @@ -76,9 +76,14 @@ pass try: + import dask import dask.array as da + + dask_version = dask.__version__ except ImportError: - pass + # needed for xfailed tests when dask < 2.4.0 + # remove when min dask > 2.4.0 + dask_version = "10.0" ON_WINDOWS = sys.platform == "win32" @@ -1723,6 +1728,7 @@ def test_hidden_zarr_keys(self): with xr.decode_cf(store): pass + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_write_persistence_modes(self): original = create_test_data() @@ -1787,6 +1793,7 @@ def test_encoding_kwarg_fixed_width_string(self): def test_dataset_caching(self): super().test_dataset_caching() + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_write(self): ds, ds_to_append, _ = create_append_test_data() with self.create_zarr_target() as store_target: @@ -1863,6 +1870,7 @@ def test_check_encoding_is_consistent_after_append(self): xr.concat([ds, ds_to_append], dim="time"), ) + @pytest.mark.skipif(LooseVersion(dask_version) < "2.4", reason="dask GH5334") def test_append_with_new_variable(self): ds, ds_to_append, ds_with_new_var = create_append_test_data() diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 2d373d12095..369903552ad 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -817,6 +817,24 @@ def test_vectorize_dask(): assert_identical(expected, actual) +@requires_dask +def test_vectorize_dask_new_output_dims(): + # regression test for GH3574 + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + func = lambda x: x[np.newaxis, ...] 
+
+    expected = data_array.expand_dims("z")
+    actual = apply_ufunc(
+        func,
+        data_array.chunk({"x": 1}),
+        output_core_dims=[["z"]],
+        vectorize=True,
+        dask="parallelized",
+        output_dtypes=[float],
+        output_sizes={"z": 1},
+    ).transpose(*expected.dims)
+    assert_identical(expected, actual)
+
+
 def test_output_wrong_number():
     variable = xr.Variable("x", np.arange(10))
diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
index a02fef2faeb..21a212c29b3 100644
--- a/xarray/tests/test_sparse.py
+++ b/xarray/tests/test_sparse.py
@@ -873,3 +873,16 @@ def test_dask_token():
     t5 = dask.base.tokenize(ac + 1)
     assert t4 != t5
     assert isinstance(ac.data._meta, sparse.COO)
+
+
+@requires_dask
+def test_apply_ufunc_meta_to_blockwise():
+    da = xr.DataArray(np.zeros((2, 3)), dims=["x", "y"]).chunk({"x": 2, "y": 1})
+    sparse_meta = sparse.COO.from_numpy(np.zeros((0, 0)))
+
+    # if dask computed meta, it would be np.ndarray
+    expected = xr.apply_ufunc(
+        lambda x: x, da, dask="parallelized", output_dtypes=[da.dtype], meta=sparse_meta
+    ).data._meta
+
+    assert_sparse_equal(expected, sparse_meta)

From 6d1434e9b9e9766d6fcafceb05e81734b29994ea Mon Sep 17 00:00:00 2001
From: Julien Seguinot
Date: Wed, 22 Jan 2020 22:09:38 +0100
Subject: [PATCH 26/29] Allow binned coordinates on 1D plots y-axis. (#3685)

* Allow binned coords on 1D plots y-axis. See #3571.
* Add helper to convert intervals for 1d plots.
* Add tests for line plots with x/y/x&y intervals.
* Add tests for step plots with x/y intervals.
* Simplify syntax in 1D intervals plot helper.
* Document new handling of y-axis intervals.
---
 doc/whats-new.rst         |  3 +++
 xarray/plot/plot.py       | 33 ++++++---------------------------
 xarray/plot/utils.py      | 36 ++++++++++++++++++++++++++++++++++++
 xarray/tests/test_plot.py | 29 +++++++++++++++++++++++++++++
 4 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 22088357e0b..c310d53190d 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -102,6 +102,9 @@ Bug fixes
   By `Justus Magin `_.
 - Fixed errors emitted by ``mypy --strict`` in modules that import xarray.
   (:issue:`3695`) by `Guido Imperiale `_.
+- Fix plotting of binned coordinates on the y axis in :py:meth:`DataArray.plot`
+  (line) and :py:meth:`DataArray.plot.step` plots (:issue:`3571`,
+  :pull:`3685`) by `Julien Seguinot `_.

 Documentation
 ~~~~~~~~~~~~~
diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py
index d38c9765352..b4802f6194b 100644
--- a/xarray/plot/plot.py
+++ b/xarray/plot/plot.py
@@ -17,13 +17,11 @@
     _ensure_plottable,
     _infer_interval_breaks,
     _infer_xy_labels,
-    _interval_to_double_bound_points,
-    _interval_to_mid_points,
     _process_cmap_cbar_kwargs,
     _rescale_imshow_rgb,
+    _resolve_intervals_1dplot,
     _resolve_intervals_2dplot,
     _update_axes,
-    _valid_other_type,
     get_axis,
     import_matplotlib_pyplot,
     label_from_attrs,
@@ -296,29 +294,10 @@ def line(
     ax = get_axis(figsize, size, aspect, ax)
     xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue)

-    # Remove pd.Intervals if contained in xplt.values.
-    if _valid_other_type(xplt.values, [pd.Interval]):
-        # Is it a step plot? 
(see matplotlib.Axes.step) - if kwargs.get("linestyle", "").startswith("steps-"): - xplt_val, yplt_val = _interval_to_double_bound_points( - xplt.values, yplt.values - ) - # Remove steps-* to be sure that matplotlib is not confused - kwargs["linestyle"] = ( - kwargs["linestyle"] - .replace("steps-pre", "") - .replace("steps-post", "") - .replace("steps-mid", "") - ) - if kwargs["linestyle"] == "": - del kwargs["linestyle"] - else: - xplt_val = _interval_to_mid_points(xplt.values) - yplt_val = yplt.values - xlabel += "_center" - else: - xplt_val = xplt.values - yplt_val = yplt.values + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, xlabel, ylabel, kwargs + ) _ensure_plottable(xplt_val, yplt_val) @@ -367,7 +346,7 @@ def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs): every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the value ``y[i]``. - 'mid': Steps occur half-way between the *x* positions. - Note that this parameter is ignored if the x coordinate consists of + Note that this parameter is ignored if one coordinate consists of :py:func:`pandas.Interval` values, e.g. as a result of :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual boundaries of the interval are used. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 3b739197fea..c900dfeff3e 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -453,6 +453,42 @@ def _interval_to_double_bound_points(xarray, yarray): return xarray, yarray +def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): + """ + Helper function to replace the values of x and/or y coordinate arrays + containing pd.Interval with their mid-points or - for step plots - double + points which double the length. + """ + + # Is it a step plot? (see matplotlib.Axes.step) + if kwargs.get("linestyle", "").startswith("steps-"): + + # Convert intervals to double points + if _valid_other_type(np.array([xval, yval]), [pd.Interval]): + raise TypeError("Can't step plot intervals against intervals.") + if _valid_other_type(xval, [pd.Interval]): + xval, yval = _interval_to_double_bound_points(xval, yval) + if _valid_other_type(yval, [pd.Interval]): + yval, xval = _interval_to_double_bound_points(yval, xval) + + # Remove steps-* to be sure that matplotlib is not confused + del kwargs["linestyle"] + + # Is it another kind of plot? 
+ else: + + # Convert intervals to mid points and adjust labels + if _valid_other_type(xval, [pd.Interval]): + xval = _interval_to_mid_points(xval) + xlabel += "_center" + if _valid_other_type(yval, [pd.Interval]): + yval = _interval_to_mid_points(yval) + ylabel += "_center" + + # return converted arguments + return xval, yval, xlabel, ylabel, kwargs + + def _resolve_intervals_2dplot(val, func_name): """ Helper function to replace the values of a coordinate array containing diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index a5402d88f3e..71cb119f0d6 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -426,9 +426,25 @@ def test_convenient_facetgrid_4d(self): d.plot(x="x", y="y", col="columns", ax=plt.gca()) def test_coord_with_interval(self): + """Test line plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot() + def test_coord_with_interval_x(self): + """Test line plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") + + def test_coord_with_interval_y(self): + """Test line plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") + + def test_coord_with_interval_xy(self): + """Test line plot with intervals on both x and y axes.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() + class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) @@ -511,10 +527,23 @@ def test_step(self): self.darray[0, 0].plot.step() def test_coord_with_interval_step(self): + """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + def test_coord_with_interval_step_x(self): + """Test step plot with intervals explicitly on x axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + + def test_coord_with_interval_step_y(self): + """Test step plot with intervals explicitly on y axis.""" + bins = [-1, 0, 1, 2] + self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") + assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) + class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True) From ecd67f457240db5bc3316c092ca0bb3a134b4239 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 24 Jan 2020 15:28:47 +0000 Subject: [PATCH 27/29] setuptools-scm and isort tweaks (#3720) * Don't require setuptools * isort __init__.py * Fix import * black --- setup.cfg | 22 ++++++------ xarray/__init__.py | 58 +++++++++++++------------------- xarray/backends/api.py | 5 ++- xarray/backends/cfgrib_.py | 2 +- xarray/backends/h5netcdf_.py | 2 +- xarray/backends/netCDF4_.py | 3 +- xarray/backends/netcdf3.py | 3 +- xarray/backends/pseudonetcdf_.py | 2 +- xarray/backends/pydap_.py | 2 +- xarray/backends/pynio_.py | 2 +- xarray/backends/rasterio_.py | 2 +- xarray/backends/scipy_.py | 2 +- xarray/backends/zarr.py | 3 +- 13 files changed, 52 insertions(+), 56 deletions(-) diff --git a/setup.cfg b/setup.cfg index 0e113c15306..ec0fd28ef73 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,7 +75,7 @@ classifiers = [options] packages = xarray -zip_safe = True +zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html include_package_data 
= True python_requires = >=3.6 install_requires = @@ -91,8 +91,8 @@ xarray = static/html/* [tool:pytest] -python_files=test_*.py -testpaths=xarray/tests properties +python_files = test_*.py +testpaths = xarray/tests properties # Fixed upstream in https://github.com/pydata/bottleneck/pull/199 filterwarnings = ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning @@ -104,7 +104,7 @@ markers = slow: slow tests [flake8] -ignore= +ignore = # whitespace before ':' - doesn't work well with black E203 E402 @@ -119,13 +119,13 @@ exclude= doc [isort] -default_section=THIRDPARTY -known_first_party=xarray -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -line_length=88 +default_section = THIRDPARTY +known_first_party = xarray +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +line_length = 88 # Most of the numerical computing stack doesn't have type annotations yet. [mypy-affine.*] diff --git a/xarray/__init__.py b/xarray/__init__.py index 0388eaf3dab..44dc66411c4 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,49 +1,39 @@ -""" isort:skip_file """ - -import pkg_resources - -try: - __version__ = pkg_resources.get_distribution("xarray").version -except Exception: - # Local copy, not installed with setuptools - __version__ = "unknown" - -from .core.alignment import align, broadcast -from .core.common import full_like, zeros_like, ones_like -from .core.concat import concat -from .core.combine import combine_by_coords, combine_nested, auto_combine -from .core.computation import apply_ufunc, dot, where -from .core.extensions import register_dataarray_accessor, register_dataset_accessor -from .core.variable import as_variable, Variable, IndexVariable, Coordinate -from .core.dataset import Dataset -from .core.dataarray import DataArray -from .core.merge import merge, MergeError -from .core.options import set_options -from .core.parallel import map_blocks - +from . import testing, tutorial, ufuncs from .backends.api import ( - open_dataset, + load_dataarray, + load_dataset, open_dataarray, + open_dataset, open_mfdataset, save_mfdataset, - load_dataset, - load_dataarray, ) from .backends.rasterio_ import open_rasterio from .backends.zarr import open_zarr - -from .conventions import decode_cf, SerializationWarning - from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex - +from .conventions import SerializationWarning, decode_cf +from .core.alignment import align, broadcast +from .core.combine import auto_combine, combine_by_coords, combine_nested +from .core.common import ALL_DIMS, full_like, ones_like, zeros_like +from .core.computation import apply_ufunc, dot, where +from .core.concat import concat +from .core.dataarray import DataArray +from .core.dataset import Dataset +from .core.extensions import register_dataarray_accessor, register_dataset_accessor +from .core.merge import MergeError, merge +from .core.options import set_options +from .core.parallel import map_blocks +from .core.variable import Coordinate, IndexVariable, Variable, as_variable from .util.print_versions import show_versions -from . import tutorial -from . import ufuncs -from . import testing +try: + import pkg_resources -from .core.common import ALL_DIMS + __version__ = pkg_resources.get_distribution("xarray").version +except Exception: + # Local copy, not installed with setuptools, or setuptools is not available. 
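+    # (for example, a plain source checkout that was never pip-installed, in
+    # which case pkg_resources.get_distribution raises DistributionNotFound)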
+ # Disable minimum version checks on downstream libraries. + __version__ = "999" # A hardcoded __all__ variable is necessary to appease # `mypy --strict` running in projects that import xarray. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 0990fc46219..56cd0649989 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -18,13 +18,16 @@ import numpy as np -from .. import DataArray, Dataset, auto_combine, backends, coding, conventions +from .. import backends, coding, conventions from ..core import indexing from ..core.combine import ( _infer_concat_order_from_positions, _nested_combine, + auto_combine, combine_by_coords, ) +from ..core.dataarray import DataArray +from ..core.dataset import Dataset from ..core.utils import close_on_error, is_grib_path, is_remote_uri from .common import AbstractDataStore, ArrayWriter from .locks import _get_scheduler diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index 97c9ac1b9b4..bd946df89b2 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -1,8 +1,8 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .locks import SerializableLock, ensure_lock diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 203213afe82..2b7c2d9057c 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -2,9 +2,9 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import FrozenDict, is_remote_uri +from ..core.variable import Variable from .common import WritableCFDataStore, find_root_and_group from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0e454ec47de..0a917cde4d7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -4,10 +4,11 @@ import numpy as np -from .. import Variable, coding +from .. import coding from ..coding.variables import pop_to from ..core import indexing from ..core.utils import FrozenDict, is_remote_uri +from ..core.variable import Variable from .common import ( BackendArray, WritableCFDataStore, diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index d26b6ce2ea9..c9c4baf9b01 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -2,7 +2,8 @@ import numpy as np -from .. import Variable, coding +from .. import coding +from ..core.variable import Variable # Special characters that are permitted in netCDF names except in the # 0th position of the string diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 63a7c3b0609..17a4eb8f6bf 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,8 +1,8 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py index 7ef4ec66241..20e943ab561 100644 --- a/xarray/backends/pydap_.py +++ b/xarray/backends/pydap_.py @@ -1,9 +1,9 @@ import numpy as np -from .. 
import Variable from ..core import indexing from ..core.pycompat import integer_types from ..core.utils import Frozen, FrozenDict, is_dict_like +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray, robust_getitem diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index d9e372ceaf9..1c66ff1ee48 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,8 +1,8 @@ import numpy as np -from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 1f9b9943573..d041e430db9 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -3,8 +3,8 @@ import numpy as np -from .. import DataArray from ..core import indexing +from ..core.dataarray import DataArray from ..core.utils import is_scalar from .common import BackendArray from .file_manager import CachingFileManager diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 3a4787fc67e..9863285d6de 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -2,9 +2,9 @@ import numpy as np -from .. import Variable from ..core.indexing import NumpyIndexingAdapter from ..core.utils import Frozen, FrozenDict +from ..core.variable import Variable from .common import BackendArray, WritableCFDataStore from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6d4ebb02a11..763769dac74 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -2,10 +2,11 @@ import numpy as np -from .. import Variable, coding, conventions +from .. import coding, conventions from ..core import indexing from ..core.pycompat import integer_types from ..core.utils import FrozenDict, HiddenKeyDict +from ..core.variable import Variable from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name # need some special secret attributes to tell us the dimensions From 9c7286639136f52aee877f44de8c89d7c8f41068 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 24 Jan 2020 21:46:53 +0000 Subject: [PATCH 28/29] bump min deps for 0.15 (#3713) * Bump numpy to 1.15, pandas to 0.25, scipy to 1.2 * Remove some backcompat code * bump pandas in doc.yml * Fix repr test * review cleanups. * scipy to 1.3 * Revert "bump pandas in doc.yml" This reverts commit d078cf113b552948c80b2437bc1cbf60f7ffc262. * bugfix. * fix tests. 
* update setup.cfg * Update whats-new --- ci/requirements/py36-bare-minimum.yml | 4 ++-- ci/requirements/py36-min-all-deps.yml | 6 +++--- ci/requirements/py36-min-nep18.yml | 2 +- doc/installing.rst | 4 ++-- doc/whats-new.rst | 8 +++++++- setup.cfg | 4 ++-- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 9 ++------- xarray/core/groupby.py | 2 +- xarray/core/nputils.py | 17 +++++------------ xarray/core/variable.py | 12 +++--------- xarray/plot/utils.py | 8 +------- xarray/tests/test_dataarray.py | 2 +- xarray/tests/test_units.py | 25 ++++++++++++++++++++----- 14 files changed, 51 insertions(+), 54 deletions(-) diff --git a/ci/requirements/py36-bare-minimum.yml b/ci/requirements/py36-bare-minimum.yml index 05186bc8748..8b604ce02dd 100644 --- a/ci/requirements/py36-bare-minimum.yml +++ b/ci/requirements/py36-bare-minimum.yml @@ -7,5 +7,5 @@ dependencies: - pytest - pytest-cov - pytest-env - - numpy=1.14 - - pandas=0.24 + - numpy=1.15 + - pandas=0.25 diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py36-min-all-deps.yml index ea707461013..dc77e232dea 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py36-min-all-deps.yml @@ -29,8 +29,8 @@ dependencies: - nc-time-axis=1.2 - netcdf4=1.4 - numba=0.44 - - numpy=1.14 - - pandas=0.24 + - numpy=1.15 + - pandas=0.25 # - pint # See py36-min-nep18.yml - pip - pseudonetcdf=3.0 @@ -40,7 +40,7 @@ dependencies: - pytest-cov - pytest-env - rasterio=1.0 - - scipy=1.0 # Policy allows for 1.2, but scipy>=1.1 breaks numpy=1.14 + - scipy=1.3 - seaborn=0.9 # - sparse # See py36-min-nep18.yml - toolz=0.10 diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py36-min-nep18.yml index fc9523ce249..8fe7644d626 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py36-min-nep18.yml @@ -9,7 +9,7 @@ dependencies: - dask=2.4 - distributed=2.4 - numpy=1.17 - - pandas=0.24 + - pandas=0.25 - pint=0.9 # Actually not enough as it doesn't implement __array_function__yet! - pytest - pytest-cov diff --git a/doc/installing.rst b/doc/installing.rst index 219cf109efe..5c39f9a3c49 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -7,8 +7,8 @@ Required dependencies --------------------- - Python (3.6 or later) -- `numpy `__ (1.14 or later) -- `pandas `__ (0.24 or later) +- `numpy `__ (1.15 or later) +- `pandas `__ (0.25 or later) Optional dependencies --------------------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c310d53190d..9fbeaa95055 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,7 +21,13 @@ v0.15.0 (unreleased) Breaking changes ~~~~~~~~~~~~~~~~ -- Bumped minimum ``dask`` version to 2.2. +- Bumped minimum tested versions for dependencies: + - numpy 1.15 + - pandas 0.25 + - dask 2.2 + - distributed 2.2 + - scipy 1.3 + - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which have been deprecated since 0.12. (:pull:`3650`). 
Instead, specify the encoding when writing to disk or set diff --git a/setup.cfg b/setup.cfg index ec0fd28ef73..e336f46e68c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -79,8 +79,8 @@ zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.htm include_package_data = True python_requires = >=3.6 install_requires = - numpy >= 1.14 - pandas >= 0.24 + numpy >= 1.15 + pandas >= 0.25 setup_requires = setuptools_scm [options.package_data] diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0e67a791834..4fc6773c459 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2969,7 +2969,7 @@ def quantile( See Also -------- - numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile Examples -------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 82ddc8a535f..5ac79999795 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4466,12 +4466,7 @@ def _set_sparse_data_from_dataframe( idx = dataframe.index if isinstance(idx, pd.MultiIndex): - try: - codes = idx.codes - except AttributeError: - # deprecated since pandas 0.24 - codes = idx.labels - coords = np.stack([np.asarray(code) for code in codes], axis=0) + coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) is_sorted = idx.is_lexsorted else: coords = np.arange(idx.size).reshape(1, -1) @@ -5171,7 +5166,7 @@ def quantile( See Also -------- - numpy.nanpercentile, pandas.Series.quantile, DataArray.quantile + numpy.nanquantile, pandas.Series.quantile, DataArray.quantile Examples -------- diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5b52f48413d..f2a9ebac6eb 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -595,7 +595,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): See Also -------- - numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile Examples diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 3fe2c254b0f..dba67174fc1 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -3,6 +3,8 @@ import numpy as np import pandas as pd +from numpy.core.multiarray import normalize_axis_index + try: import bottleneck as bn @@ -13,15 +15,6 @@ _USE_BOTTLENECK = False -def _validate_axis(data, axis): - ndim = data.ndim - if not -ndim <= axis < ndim: - raise IndexError(f"axis {axis!r} out of bounds [-{ndim}, {ndim})") - if axis < 0: - axis += ndim - return axis - - def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) sl = other_ind[:axis] + (idx,) + other_ind[axis:] @@ -29,13 +22,13 @@ def _select_along_axis(values, idx, axis): def nanfirst(values, axis): - axis = _validate_axis(values, axis) + axis = normalize_axis_index(axis, values.ndim) idx_first = np.argmax(~pd.isnull(values), axis=axis) return _select_along_axis(values, idx_first, axis) def nanlast(values, axis): - axis = _validate_axis(values, axis) + axis = normalize_axis_index(axis, values.ndim) rev = (slice(None),) * axis + (slice(None, None, -1),) idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis) return _select_along_axis(values, idx_last, axis) @@ -186,7 +179,7 @@ def _rolling_window(a, window, axis=-1): This function is taken from https://github.com/numpy/numpy/pull/31 but slightly modified to accept axis option. 
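+
+    For illustration, a minimal sketch of the returned view (hypothetical
+    example, not part of the upstream docstring):
+
+    >>> a = np.arange(5)
+    >>> _rolling_window(a, window=3, axis=-1)
+    array([[0, 1, 2],
+           [1, 2, 3],
+           [2, 3, 4]])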
""" - axis = _validate_axis(a, axis) + axis = normalize_axis_index(axis, a.ndim) a = np.swapaxes(a, axis, -1) if window < 1: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0a9d0767b77..74d5d57e6f6 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1722,7 +1722,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): See Also -------- - numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile, + numpy.nanquantile, pandas.Series.quantile, Dataset.quantile, DataArray.quantile """ @@ -1734,10 +1734,6 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): scalar = utils.is_scalar(q) q = np.atleast_1d(np.asarray(q, dtype=np.float64)) - # TODO: remove once numpy >= 1.15.0 is the minimum requirement - if np.count_nonzero(q < 0.0) or np.count_nonzero(q > 1.0): - raise ValueError("Quantiles must be in the range [0, 1]") - if dim is None: dim = self.dims @@ -1746,9 +1742,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None): def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - - # TODO: use np.nanquantile once numpy >= 1.15.0 is the minimum requirement - return np.moveaxis(np.nanpercentile(npa, **kwargs), 0, -1) + return np.moveaxis(np.nanquantile(npa, **kwargs), 0, -1) axis = np.arange(-1, -1 * len(dim) - 1, -1) result = apply_ufunc( @@ -1760,7 +1754,7 @@ def _wrapper(npa, **kwargs): output_dtypes=[np.float64], output_sizes={"quantile": len(q)}, dask="parallelized", - kwargs={"q": q * 100, "axis": axis, "interpolation": interpolation}, + kwargs={"q": q, "axis": axis, "interpolation": interpolation}, ) # for backward compatibility diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index c900dfeff3e..6eec7c6b433 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -583,10 +583,6 @@ def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): def _rescale_imshow_rgb(darray, vmin, vmax, robust): assert robust or vmin is not None or vmax is not None - # TODO: remove when min numpy version is bumped to 1.13 - # There's a cyclic dependency via DataArray, so we can't import from - # xarray.ufuncs in global scope. - from xarray.ufuncs import maximum, minimum # Calculate vmin and vmax automatically for `robust=True` if robust: @@ -615,9 +611,7 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): # After scaling, downcast to 32-bit float. This substantially reduces # memory usage after we hand `darray` off to matplotlib. 
darray = ((darray.astype("f8") - vmin) / (vmax - vmin)).astype("f4") - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "xarray.ufuncs", PendingDeprecationWarning) - return minimum(maximum(darray, 0), 1) + return np.minimum(np.maximum(darray, 0), 1) def _update_axes( diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 962be7548bc..b9b719e8af9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -80,7 +80,7 @@ def test_repr_multiindex(self): assert expected == repr(self.mda) @pytest.mark.skipif( - LooseVersion(np.__version__) < "1.15", + LooseVersion(np.__version__) < "1.16", reason="old versions of numpy have different printing behavior", ) def test_repr_multiindex_long(self): diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 6f9128839a7..d98e5e23516 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1838,7 +1838,10 @@ def test_squeeze(self, dtype): "func", ( method("coarsen", windows={"y": 2}, func=np.mean), - method("quantile", q=[0.25, 0.75]), + pytest.param( + method("quantile", q=[0.25, 0.75]), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), + ), pytest.param( method("rank", dim="x"), marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"), @@ -3401,7 +3404,10 @@ def test_stacking_reordering(self, func, dtype): method("diff", dim="x"), method("differentiate", coord="x"), method("integrate", dim="x"), - method("quantile", q=[0.25, 0.75]), + pytest.param( + method("quantile", q=[0.25, 0.75]), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), + ), method("reduce", func=np.sum, dim="x"), pytest.param( lambda x: x.dot(x), @@ -3491,7 +3497,10 @@ def test_resample(self, dtype): method("assign_coords", z=(["x"], np.arange(5) * unit_registry.s)), method("first"), method("last"), - method("quantile", q=np.array([0.25, 0.5, 0.75]), dim="x"), + pytest.param( + method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), + ), ), ids=repr, ) @@ -4929,7 +4938,10 @@ def test_reindex_like(self, unit, error, dtype): method("diff", dim="x"), method("differentiate", coord="x"), method("integrate", coord="x"), - method("quantile", q=[0.25, 0.75]), + pytest.param( + method("quantile", q=[0.25, 0.75]), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), + ), method("reduce", func=np.sum, dim="x"), method("map", np.fabs), ), @@ -5039,7 +5051,10 @@ def test_resample(self, dtype): method("assign_coords", v=("x", np.arange(10) * unit_registry.s)), method("first"), method("last"), - method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + pytest.param( + method("quantile", q=[0.25, 0.5, 0.75], dim="x"), + marks=pytest.mark.xfail(reason="nanquantile not implemented"), + ), ), ids=repr, ) From cc142f430f9f468c990b6607ddf3424b0facf054 Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Sun, 26 Jan 2020 07:38:20 +0900 Subject: [PATCH 29/29] sel with categorical index (#3670) * Added a support with categorical index * fix from_dataframe * update a test * Added more tests * black * Added a test to make sure raising ValueErrors * remove unnecessary print * added a test for reindex * Fix according to reviews * blacken * delete trailing whitespace Co-authored-by: Deepak Cherian Co-authored-by: keewis --- doc/whats-new.rst | 2 ++ xarray/core/dataset.py | 22 +++++++----- xarray/core/indexes.py | 20 +++++++++++ xarray/core/indexing.py | 10 ++++++ xarray/tests/test_dataset.py | 65 
++++++++++++++++++++++++++++++++++++ 5 files changed, 110 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9fbeaa95055..0e48dfcfc78 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,8 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.sel` and :py:meth:`Dataset.sel` now support :py:class:`pandas.CategoricalIndex`. (:issue:`3669`) + By `Keisuke Fujii `_. - Support using an existing, opened h5netcdf ``File`` with :py:class:`~xarray.backends.H5NetCDFStore`. This permits creating an :py:class:`~xarray.Dataset` from a h5netcdf ``File`` that has been opened diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5ac79999795..6f1ea76be4c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -64,6 +64,7 @@ default_indexes, isel_variable_and_index, propagate_indexes, + remove_unused_levels_categories, roll_index, ) from .indexing import is_fancy_indexer @@ -3411,7 +3412,7 @@ def ensure_stackable(val): def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset": index = self.get_index(dim) - index = index.remove_unused_levels() + index = remove_unused_levels_categories(index) full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) # take a shortcut in case the MultiIndex was not modified. @@ -4460,7 +4461,7 @@ def to_dataframe(self): return self._to_dataframe(self.dims) def _set_sparse_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] + self, dataframe: pd.DataFrame, dims: tuple ) -> None: from sparse import COO @@ -4468,9 +4469,11 @@ def _set_sparse_data_from_dataframe( if isinstance(idx, pd.MultiIndex): coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) is_sorted = idx.is_lexsorted + shape = tuple(lev.size for lev in idx.levels) else: coords = np.arange(idx.size).reshape(1, -1) is_sorted = True + shape = (idx.size,) for name, series in dataframe.items(): # Cast to a NumPy array first, in case the Series is a pandas @@ -4495,14 +4498,16 @@ def _set_sparse_data_from_dataframe( self[name] = (dims, data) def _set_numpy_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple, shape: Tuple[int, ...] 
+ self, dataframe: pd.DataFrame, dims: tuple ) -> None: idx = dataframe.index if isinstance(idx, pd.MultiIndex): # expand the DataFrame to include the product of all levels full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) dataframe = dataframe.reindex(full_idx) - + shape = tuple(lev.size for lev in idx.levels) + else: + shape = (idx.size,) for name, series in dataframe.items(): data = np.asarray(series).reshape(shape) self[name] = (dims, data) @@ -4543,7 +4548,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas if not dataframe.columns.is_unique: raise ValueError("cannot convert DataFrame with non-unique columns") - idx = dataframe.index + idx = remove_unused_levels_categories(dataframe.index) + dataframe = dataframe.set_index(idx) obj = cls() if isinstance(idx, pd.MultiIndex): @@ -4553,17 +4559,15 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas ) for dim, lev in zip(dims, idx.levels): obj[dim] = (dim, lev) - shape = tuple(lev.size for lev in idx.levels) else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) obj[index_name] = (dims, idx) - shape = (idx.size,) if sparse: - obj._set_sparse_data_from_dataframe(dataframe, dims, shape) + obj._set_sparse_data_from_dataframe(dataframe, dims) else: - obj._set_numpy_data_from_dataframe(dataframe, dims, shape) + obj._set_numpy_data_from_dataframe(dataframe, dims) return obj def to_dask_dataframe(self, dim_order=None, set_index=False): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 8337a0f082a..06bf08cefd2 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -9,6 +9,26 @@ from .variable import Variable +def remove_unused_levels_categories(index): + """ + Remove unused levels from MultiIndex and unused categories from CategoricalIndex + """ + if isinstance(index, pd.MultiIndex): + index = index.remove_unused_levels() + # if it contains CategoricalIndex, we need to remove unused categories + # manually. See https://github.com/pandas-dev/pandas/issues/30846 + if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + levels = [] + for i, level in enumerate(index.levels): + if isinstance(level, pd.CategoricalIndex): + level = level[index.codes[i]].remove_unused_categories() + levels.append(level) + index = pd.MultiIndex.from_arrays(levels, names=index.names) + elif isinstance(index, pd.CategoricalIndex): + index = index.remove_unused_categories() + return index + + class Indexes(collections.abc.Mapping): """Immutable proxy for Dataset or DataArrary indexes.""" diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8e851b39c3e..4e58be1ad2f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -175,6 +175,16 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No if label.ndim == 0: if isinstance(index, pd.MultiIndex): indexer, new_index = index.get_loc_level(label.item(), level=0) + elif isinstance(index, pd.CategoricalIndex): + if method is not None: + raise ValueError( + "'method' is not a valid kwarg when indexing using a CategoricalIndex." + ) + if tolerance is not None: + raise ValueError( + "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." 
+ ) + indexer = index.get_loc(label.item()) else: indexer = index.get_loc( label.item(), method=method, tolerance=tolerance diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f9eb37dbf2f..4e51e229b29 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -1408,6 +1408,56 @@ def test_sel_dataarray_mindex(self): ) ) + def test_sel_categorical(self): + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + actual = ds.sel(ind="bar") + expected = ds.isel(ind=1) + assert_identical(expected, actual) + + def test_sel_categorical_error(self): + ind = pd.Series(["foo", "bar"], dtype="category") + df = pd.DataFrame({"ind": ind, "values": [1, 2]}) + ds = df.set_index("ind").to_xarray() + with pytest.raises(ValueError): + ds.sel(ind="bar", method="nearest") + with pytest.raises(ValueError): + ds.sel(ind="bar", tolerance="nearest") + + def test_categorical_index(self): + cat = pd.CategoricalIndex( + ["foo", "bar", "foo"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 1])}, + ) + # test slice + actual = ds.sel(cat="foo") + expected = ds.isel(cat=[0, 2]) + assert_identical(expected, actual) + # make sure the conversion to the array works + actual = ds.sel(cat="foo")["cat"].values + assert (actual == np.array(["foo", "foo"])).all() + + ds = ds.set_index(index=["cat", "c"]) + actual = ds.unstack("index") + assert actual["var"].shape == (2, 2) + + def test_categorical_reindex(self): + cat = pd.CategoricalIndex( + ["foo", "bar", "baz"], + categories=["foo", "bar", "baz", "qux", "quux", "corge"], + ) + ds = xr.Dataset( + {"var": ("cat", np.arange(3))}, + coords={"cat": ("cat", cat), "c": ("cat", [0, 1, 2])}, + ) + actual = ds.reindex(cat=["foo"])["cat"].values + assert (actual == np.array(["foo"])).all() + def test_sel_drop(self): data = Dataset({"foo": ("x", [1, 2, 3])}, {"x": [0, 1, 2]}) expected = Dataset({"foo": 1}) @@ -3865,6 +3915,21 @@ def test_to_and_from_dataframe(self): expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_categorical(self): + cat = pd.CategoricalDtype( + categories=["foo", "bar", "baz", "qux", "quux", "corge"] + ) + i1 = pd.Series(["foo", "bar", "foo"], dtype=cat) + i2 = pd.Series(["bar", "bar", "baz"], dtype=cat) + + df = pd.DataFrame({"i1": i1, "i2": i2, "values": [1, 2, 3]}) + ds = df.set_index("i1").to_xarray() + assert len(ds["i1"]) == 3 + + ds = df.set_index(["i1", "i2"]).to_xarray() + assert len(ds["i1"]) == 2 + assert len(ds["i2"]) == 2 + @requires_sparse def test_from_dataframe_sparse(self): import sparse