From 11d65f99765ec8ff305d6bd0cbc93dcb5f3d7b79 Mon Sep 17 00:00:00 2001 From: Gerrit Holl Date: Mon, 19 Mar 2018 16:35:22 +0000 Subject: [PATCH] Make UADA work with __array_ufunc__ Some of xarray and numpy work with __array_ufunc__ nowadays. Make sure this is supported by UADA to keep up to date. --- typhon/physics/units/tools.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/typhon/physics/units/tools.py b/typhon/physics/units/tools.py index ef32bbaf..da4ea664 100644 --- a/typhon/physics/units/tools.py +++ b/typhon/physics/units/tools.py @@ -18,6 +18,9 @@ class UnitsAwareDataArray(xarray.DataArray): """Like xarray.DataArray, but transfers units """ + # need to keep both __array_wrap__ and __array_ufunc__. Although the + # former supersedes the latter, xarrays methods explicitly call the + # former sometimes. def __array_wrap__(self, obj, context=None): new_var = super().__array_wrap__(obj, context) if self.attrs.get("units"): @@ -56,6 +59,33 @@ def _apply_rbinary_op_to_units(self, func, other, x): ureg.Quantity(1, self.attrs["units"]),).u) return x + def __array_ufunc__(self, ufunc, method, *args, **kwargs): + new_var = super().__array_ufunc__(ufunc, method, *args, **kwargs) + # make sure we're still UADA + new_var = self.__class__(new_var) + if self.attrs.get("units"): + if method == "__call__": + q = ufunc(ureg.Quantity(1, self.attrs.get("units"))) + try: + u = q.u + except AttributeError: + if (ureg(self.attrs["units"]).dimensionless or + new_var.dtype.kind == "b"): + # expected, see https://github.com/hgrecco/pint/issues/482 + u = ureg.dimensionless + else: + raise + # for exp and log, values are not set correctly. I'm + # not sure why. Perhaps related to + # https://github.com/hgrecco/pint/issues/493 + new_var.values = ufunc(ureg.Quantity(self.values, self.units)) + new_var.attrs["units"] = str(u) + else: # unary operators? always retain units? + raise NotImplementedError("Not implented") + new_var.attrs["units"] = str(self.attrs.get("units")) + + return new_var + # pow is different because resulting unit depends on argument, not on # unit of argument def __pow__(self, other):