Skip to content

Commit

Permalink
add missing pint integration tests (#3508)
Browse files Browse the repository at this point in the history
* add tests for broadcast_like

* add tests for DataArray head / tail / thin

* update whats-new.rst
  • Loading branch information
keewis authored and max-sixty committed Nov 10, 2019
1 parent f14edf3 commit 4e9240a
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 1 deletion.
2 changes: 1 addition & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Internal Changes
~~~~~~~~~~~~~~~~

- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`, :pull:`3447`) by `Justus Magin <https://github.com/keewis>`_.
(:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin <https://github.com/keewis>`_.

.. note::

Expand Down
107 changes: 107 additions & 0 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,36 @@ def test_comparisons(self, func, variation, unit, dtype):

assert expected == result

@pytest.mark.xfail(reason="blocked by `where`")
@pytest.mark.parametrize(
"unit",
(
pytest.param(1, id="no_unit"),
pytest.param(unit_registry.dimensionless, id="dimensionless"),
pytest.param(unit_registry.s, id="incompatible_unit"),
pytest.param(unit_registry.cm, id="compatible_unit"),
pytest.param(unit_registry.m, id="identical_unit"),
),
)
def test_broadcast_like(self, unit, dtype):
array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa

x1 = np.arange(2) * unit_registry.m
x2 = np.arange(2) * unit
y1 = np.array([0]) * unit_registry.m
y2 = np.arange(3) * unit

arr1 = xr.DataArray(data=array1, coords={"x": x1, "y": y1}, dims=("x", "y"))
arr2 = xr.DataArray(data=array2, coords={"x": x2, "y": y2}, dims=("x", "y"))

expected = attach_units(
strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1)
)
result = arr1.broadcast_like(arr2)

assert_equal_with_units(expected, result)

@pytest.mark.parametrize(
"unit",
(
Expand Down Expand Up @@ -1303,6 +1333,49 @@ def test_squeeze(self, shape, dtype):
np.squeeze(array, axis=index), data_array.squeeze(dim=name)
)

@pytest.mark.xfail(
reason="indexes strip units and head / tail / thin only support integers"
)
@pytest.mark.parametrize(
"unit,error",
(
pytest.param(1, DimensionalityError, id="no_unit"),
pytest.param(
unit_registry.dimensionless, DimensionalityError, id="dimensionless"
),
pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
pytest.param(unit_registry.cm, None, id="compatible_unit"),
pytest.param(unit_registry.m, None, id="identical_unit"),
),
)
@pytest.mark.parametrize(
"func",
(method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)),
ids=repr,
)
def test_head_tail_thin(self, func, unit, error, dtype):
array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK

coords = {
"x": np.arange(10) * unit_registry.m,
"y": np.arange(5) * unit_registry.m,
}

arr = xr.DataArray(data=array, coords=coords, dims=("x", "y"))

kwargs = {name: value * unit for name, value in func.kwargs.items()}

if error is not None:
with pytest.raises(error):
func(arr, **kwargs)

return

expected = attach_units(func(strip_units(arr)), extract_units(arr))
result = func(arr, **kwargs)

assert_equal_with_units(expected, result)

@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -2472,6 +2545,40 @@ def test_comparisons(self, func, variation, unit, dtype):

assert expected == result

@pytest.mark.xfail(reason="blocked by `where`")
@pytest.mark.parametrize(
"unit",
(
pytest.param(1, id="no_unit"),
pytest.param(unit_registry.dimensionless, id="dimensionless"),
pytest.param(unit_registry.s, id="incompatible_unit"),
pytest.param(unit_registry.cm, id="compatible_unit"),
pytest.param(unit_registry.m, id="identical_unit"),
),
)
def test_broadcast_like(self, unit, dtype):
array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa

x1 = np.arange(2) * unit_registry.m
x2 = np.arange(2) * unit
y1 = np.array([0]) * unit_registry.m
y2 = np.arange(3) * unit

ds1 = xr.Dataset(
data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1}
)
ds2 = xr.Dataset(
data_vars={"a": (("x", "y"), array2)}, coords={"x": x2, "y": y2}
)

expected = attach_units(
strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1)
)
result = ds1.broadcast_like(ds2)

assert_equal_with_units(expected, result)

@pytest.mark.parametrize(
"unit",
(
Expand Down

0 comments on commit 4e9240a

Please sign in to comment.