Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgsavage committed Mar 18, 2019
1 parent e32e6e9 commit 4dcbe78
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 46 deletions.
9 changes: 7 additions & 2 deletions pint/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ def implement_func(func_str, pre_calc_units_, post_calc_units_, out_units_):
"""
func = getattr(np,func_str)
print(func_str)

@implements(func)
def _(*args, **kwargs):
# TODO make work for kwargs
print(func_str)
print("_",func_str)
(pre_calc_units, post_calc_units, out_units)=(pre_calc_units_, post_calc_units_, out_units_)
first_input_units=_get_first_input_units(args, kwargs)
if pre_calc_units == "consistent_infer":
Expand Down Expand Up @@ -190,7 +191,10 @@ def _(*args, **kwargs):
elif out_units == "infer_from_input":
out_units = first_input_units
return post_calc_Q_.to(out_units)

@implements(np.power)
def _power(*args, **kwargs):
print(args)
pass
for func_str in ['linspace', 'concatenate', 'block', 'stack', 'hstack', 'vstack', 'dstack', 'atleast_1d', 'column_stack', 'atleast_2d', 'atleast_3d', 'expand_dims','squeeze', 'swapaxes', 'compress', 'searchsorted' ,'rollaxis', 'broadcast_to', 'moveaxis', 'diff', 'ediff1d', 'fix']:
implement_func(func_str, 'consistent_infer', 'as_pre_calc', 'as_post_calc')

Expand Down Expand Up @@ -234,6 +238,7 @@ class BaseQuantity(PrettyIPython, SharedRegistryObject):
:type units: UnitsContainer, str or Quantity.
"""
def __array_function__(self, func, types, args, kwargs):
print("__array_function__", func)
if func not in HANDLED_FUNCTIONS:
return NotImplemented
if not all(issubclass(t, BaseQuantity) for t in types):
Expand Down
73 changes: 29 additions & 44 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,38 @@ def test_cross(self):
def test_trapz(self):
self.assertQuantityEqual(np.trapz([1. ,2., 3., 4.] * self.ureg.J, dx=1*self.ureg.m), 7.5 * self.ureg.J*self.ureg.m)
# Arithmetic operations
@unittest.expectedFailure

def test_power(self):
"""This is not supported as different elements might end up with different units
arr = np.array(range(3), dtype=np.float)
q = self.Q_(arr, 'meter')

eg. ([1, 1] * m) ** [2, 3]
for op_ in [op.pow, op.ipow, np.power]:
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, 2., q_cp)
arr_cp = copy.copy(arr)
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, q_cp, arr_cp)
q_cp = copy.copy(q)
q2_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, q_cp, q2_cp)

Must force exponent to single value
"""
self._test2(np.power, self.q1,
(self.qless, np.asarray([1., 2, 3, 4])),
(self.q2, ),)
@unittest.expectedFailure
@helpers.requires_numpy()
def test_exponentiation_array_exp_2(self):
arr = np.array(range(3), dtype=np.float)
#q = self.Q_(copy.copy(arr), None)
q = self.Q_(copy.copy(arr), 'meter')
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
# this fails as expected since numpy 1.8.0 but...
self.assertRaises(DimensionalityError, op.pow, arr_cp, q_cp)
# ..not for op.ipow !
# q_cp is treated as if it is an array. The units are ignored.
# BaseQuantity.__ipow__ is never called
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op.ipow, arr_cp, q_cp)

class TestNumpyUnclassified(TestNumpyMethods):
def test_tolist(self):
Expand Down Expand Up @@ -493,39 +514,3 @@ def test_right_shift(self):
(self.qless, 2),
(self.q1, self.q2, self.qs, ),
'same')


class TestNDArrayQuantityMath(QuantityTestCase):

@helpers.requires_numpy()
def test_exponentiation_array_exp(self):
arr = np.array(range(3), dtype=np.float)
q = self.Q_(arr, 'meter')

for op_ in [op.pow, op.ipow]:
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, 2., q_cp)
arr_cp = copy.copy(arr)
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, q_cp, arr_cp)
q_cp = copy.copy(q)
q2_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op_, q_cp, q2_cp)

@unittest.expectedFailure
@helpers.requires_numpy()
def test_exponentiation_array_exp_2(self):
arr = np.array(range(3), dtype=np.float)
#q = self.Q_(copy.copy(arr), None)
q = self.Q_(copy.copy(arr), 'meter')
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
# this fails as expected since numpy 1.8.0 but...
self.assertRaises(DimensionalityError, op.pow, arr_cp, q_cp)
# ..not for op.ipow !
# q_cp is treated as if it is an array. The units are ignored.
# BaseQuantity.__ipow__ is never called
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op.ipow, arr_cp, q_cp)

0 comments on commit 4dcbe78

Please sign in to comment.