diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d9c85a19ce6..96f0ba9a4a6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -111,7 +111,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Added integration tests against `pint `_. - (:pull:`3238`, :pull:`3447`) by `Justus Magin `_. + (:pull:`3238`, :pull:`3447`, :pull:`3508`) by `Justus Magin `_. .. note:: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 8eed1f0dbe3..fd9e9b039ac 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1045,6 +1045,36 @@ def test_comparisons(self, func, variation, unit, dtype): assert expected == result + @pytest.mark.xfail(reason="blocked by `where`") + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_like(self, unit, dtype): + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa + + x1 = np.arange(2) * unit_registry.m + x2 = np.arange(2) * unit + y1 = np.array([0]) * unit_registry.m + y2 = np.arange(3) * unit + + arr1 = xr.DataArray(data=array1, coords={"x": x1, "y": y1}, dims=("x", "y")) + arr2 = xr.DataArray(data=array2, coords={"x": x2, "y": y2}, dims=("x", "y")) + + expected = attach_units( + strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1) + ) + result = arr1.broadcast_like(arr2) + + assert_equal_with_units(expected, result) + @pytest.mark.parametrize( "unit", ( @@ -1303,6 +1333,49 @@ def test_squeeze(self, shape, dtype): np.squeeze(array, axis=index), data_array.squeeze(dim=name) ) + @pytest.mark.xfail( + reason="indexes strip units and head / tail / thin only support integers" + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + (method("head", x=7, y=3), method("tail", x=7, y=3), method("thin", x=7, y=3)), + ids=repr, + ) + def test_head_tail_thin(self, func, unit, error, dtype): + array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK + + coords = { + "x": np.arange(10) * unit_registry.m, + "y": np.arange(5) * unit_registry.m, + } + + arr = xr.DataArray(data=array, coords=coords, dims=("x", "y")) + + kwargs = {name: value * unit for name, value in func.kwargs.items()} + + if error is not None: + with pytest.raises(error): + func(arr, **kwargs) + + return + + expected = attach_units(func(strip_units(arr)), extract_units(arr)) + result = func(arr, **kwargs) + + assert_equal_with_units(expected, result) + @pytest.mark.parametrize( "unit,error", ( @@ -2472,6 +2545,40 @@ def test_comparisons(self, func, variation, unit, dtype): assert expected == result + @pytest.mark.xfail(reason="blocked by `where`") + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_like(self, unit, dtype): + array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa + array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa + + x1 = np.arange(2) * unit_registry.m + x2 = np.arange(2) * unit + y1 = np.array([0]) * unit_registry.m + y2 = np.arange(3) * unit + + ds1 = xr.Dataset( + data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1} + ) + ds2 = xr.Dataset( + data_vars={"a": (("x", "y"), array2)}, coords={"x": x2, "y": y2} + ) + + expected = attach_units( + strip_units(ds1).broadcast_like(strip_units(ds2)), extract_units(ds1) + ) + result = ds1.broadcast_like(ds2) + + assert_equal_with_units(expected, result) + @pytest.mark.parametrize( "unit", (