Skip to content

Commit

Permalink
Allow swap_dims to take kwargs (#4841)
Browse files Browse the repository at this point in the history
  • Loading branch information
max-sixty authored Jan 24, 2021
1 parent bc35548 commit d555172
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 3 deletions.
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ Breaking changes
New Features
~~~~~~~~~~~~
- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables.
By `Deepak Cherian <https://github.com/dcherian>`_
By `Deepak Cherian <https://github.com/dcherian>`_.
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims
in the form of kwargs as well as a dict, like most similar methods.
By `Maximilian Roos <https://github.com/max-sixty>`_.

Bug fixes
~~~~~~~~~
Expand Down
9 changes: 8 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,9 @@ def rename(
new_name_or_name_dict = cast(Hashable, new_name_or_name_dict)
return self._replace(name=new_name_or_name_dict)

def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
def swap_dims(
self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
) -> "DataArray":
"""Returns a new DataArray with swapped dimensions.
Parameters
Expand All @@ -1708,6 +1710,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
Dictionary whose keys are current dimension names and whose values
are new names.
**dim_kwargs : {dim: , ...}, optional
The keyword arguments form of ``dims_dict``.
One of dims_dict or dims_kwargs must be provided.
Returns
-------
swapped : DataArray
Expand Down Expand Up @@ -1749,6 +1755,7 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray":
DataArray.rename
Dataset.swap_dims
"""
dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)

Expand Down
10 changes: 9 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3155,7 +3155,9 @@ def rename_vars(
)
return self._replace(variables, coord_names, dims=dims, indexes=indexes)

def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
def swap_dims(
self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
) -> "Dataset":
"""Returns a new object with swapped dimensions.
Parameters
Expand All @@ -3164,6 +3166,10 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
Dictionary whose keys are current dimension names and whose values
are new names.
**dim_kwargs : {existing_dim: new_dim, ...}, optional
The keyword arguments form of ``dims_dict``.
One of dims_dict or dims_kwargs must be provided.
Returns
-------
swapped : Dataset
Expand Down Expand Up @@ -3214,6 +3220,8 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset":
"""
# TODO: deprecate this method in favor of a (less confusing)
# rename_dims() method that only renames dimensions.

dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
for k, v in dims_dict.items():
if k not in self.dims:
raise ValueError(
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,16 @@ def test_swap_dims(self):
expected.indexes[dim_name], actual.indexes[dim_name]
)

# as kwargs
array = DataArray(np.random.randn(3), {"x": list("abc")}, "x")
expected = DataArray(array.values, {"x": ("y", list("abc"))}, dims="y")
actual = array.swap_dims(x="y")
assert_identical(expected, actual)
for dim_name in set().union(expected.indexes.keys(), actual.indexes.keys()):
pd.testing.assert_index_equal(
expected.indexes[dim_name], actual.indexes[dim_name]
)

# multiindex case
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
array = DataArray(np.random.randn(3), {"y": ("x", idx)}, "x")
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,6 +2748,13 @@ def test_swap_dims(self):
actual = original.swap_dims({"x": "u"})
assert_identical(expected, actual)

# as kwargs
expected = Dataset(
{"y": ("u", list("abc")), "z": 42}, coords={"x": ("u", [1, 2, 3])}
)
actual = original.swap_dims(x="u")
assert_identical(expected, actual)

# handle multiindex case
idx = pd.MultiIndex.from_arrays([list("aab"), list("yzz")], names=["y1", "y2"])
original = Dataset({"x": [1, 2, 3], "y": ("x", idx), "z": 42})
Expand Down

0 comments on commit d555172

Please sign in to comment.