Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add np.pad and np.resize implementations #956

Merged
merged 1 commit into from
Dec 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ The following ufuncs_ can be applied to a Quantity object:

And the following NumPy functions:

- alen, amax, amin, append, argmax, argmin, argsort, around, atleast_1d, atleast_2d, atleast_3d, average, block, broadcast_to, clip, column_stack, compress, concatenate, copy, copyto, count_nonzero, cross, cumprod, cumproduct, cumsum, diagonal, diff, dot, dstack, ediff1d, einsum, empty_like, expand_dims, fix, flip, full_like, gradient, hstack, insert, interp, isclose, iscomplex, isin, isreal, linspace, mean, median, meshgrid, moveaxis, nan_to_num, nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean, nanmedian, nanmin, nanpercentile, nanstd, nanvar, ndim, nonzero, ones_like, percentile, ptp, ravel, result_type, rollaxis, rot90, round\_, searchsorted, shape, size, sort, squeeze, stack, std, sum, swapaxes, tile, transpose, trapz, trim_zeros, unwrap, var, vstack, where, zeros_like
- alen, amax, amin, append, argmax, argmin, argsort, around, atleast_1d, atleast_2d, atleast_3d, average, block, broadcast_to, clip, column_stack, compress, concatenate, copy, copyto, count_nonzero, cross, cumprod, cumproduct, cumsum, diagonal, diff, dot, dstack, ediff1d, einsum, empty_like, expand_dims, fix, flip, full_like, gradient, hstack, insert, interp, isclose, iscomplex, isin, isreal, linspace, mean, median, meshgrid, moveaxis, nan_to_num, nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean, nanmedian, nanmin, nanpercentile, nanstd, nanvar, ndim, nonzero, ones_like, pad, percentile, ptp, ravel, resize, result_type, rollaxis, rot90, round\_, searchsorted, shape, size, sort, squeeze, stack, std, sum, swapaxes, tile, transpose, trapz, trim_zeros, unwrap, var, vstack, where, zeros_like

And the following `NumPy ndarray methods`_:

Expand Down
26 changes: 26 additions & 0 deletions pint/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,31 @@ def _isin(element, test_elements, assume_unique=False, invert=False):
return np.isin(element.m, test_elements, assume_unique=assume_unique, invert=invert)


@implements("pad", "function")
def _pad(array, pad_width, mode="constant", **kwargs):
def _recursive_convert(arg, unit):
if iterable(arg):
return tuple(_recursive_convert(a, unit=unit) for a in arg)
elif _is_quantity(arg):
return arg.m_as(unit)
else:
return arg

# pad only dispatches on array argument, so we know it is a Quantity
units = array.units

# Handle flexible constant_values and end_values, converting to units if Quantity
# and ignoring if not
if mode == "constant":
kwargs["constant_values"] = _recursive_convert(kwargs["constant_values"], units)
elif mode == "linear_ramp":
kwargs["end_values"] = _recursive_convert(kwargs["end_values"], units)

return units._REGISTRY.Quantity(
np.pad(array._magnitude, pad_width, mode=mode, **kwargs), units
)


# Implement simple matching-unit or stripped-unit functions based on signature


Expand Down Expand Up @@ -675,6 +700,7 @@ def implementation(*args, **kwargs):
("tile", "A", True),
("rot90", "m", True),
("insert", ["arr", "values"], True),
("resize", "a", True),
]:
implement_consistent_units_by_argument(func_str, unit_arguments, wrap_output)

Expand Down
96 changes: 96 additions & 0 deletions pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,102 @@ def test_array_protocol_unavailable(self):
for attr in ("__array_struct__", "__array_interface__"):
self.assertRaises(AttributeError, getattr, self.q, attr)

@helpers.requires_array_function_protocol()
def test_resize(self):
self.assertQuantityEqual(
np.resize(self.q, (2, 4)), [[1, 2, 3, 4], [1, 2, 3, 4]] * self.ureg.m
)

@helpers.requires_array_function_protocol()
def test_pad(self):
# Tests reproduced with modification from NumPy documentation
a = [1, 2, 3, 4, 5] * self.ureg.m
self.assertQuantityEqual(
np.pad(a, (2, 3), "constant", constant_values=(4, 600 * self.ureg.cm)),
[4, 4, 1, 2, 3, 4, 5, 6, 6, 6] * self.ureg.m,
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "edge"), [1, 1, 1, 2, 3, 4, 5, 5, 5, 5] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "linear_ramp", end_values=(5, -4) * self.ureg.m),
[5, 3, 1, 2, 3, 4, 5, 2, -1, -4] * self.ureg.m,
)
self.assertQuantityEqual(
np.pad(a, (2,), "maximum"), [5, 5, 1, 2, 3, 4, 5, 5, 5] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(a, (2,), "mean"), [3, 3, 1, 2, 3, 4, 5, 3, 3] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(a, (2,), "median"), [3, 3, 1, 2, 3, 4, 5, 3, 3] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(self.q, ((3, 2), (2, 3)), "minimum"),
[
[1, 1, 1, 2, 1, 1, 1],
[1, 1, 1, 2, 1, 1, 1],
[1, 1, 1, 2, 1, 1, 1],
[1, 1, 1, 2, 1, 1, 1],
[3, 3, 3, 4, 3, 3, 3],
[1, 1, 1, 2, 1, 1, 1],
[1, 1, 1, 2, 1, 1, 1],
]
* self.ureg.m,
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "reflect"), [3, 2, 1, 2, 3, 4, 5, 4, 3, 2] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "reflect", reflect_type="odd"),
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8] * self.ureg.m,
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "symmetric"), [2, 1, 1, 2, 3, 4, 5, 5, 4, 3] * self.ureg.m
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "symmetric", reflect_type="odd"),
[0, 1, 1, 2, 3, 4, 5, 5, 6, 7] * self.ureg.m,
)
self.assertQuantityEqual(
np.pad(a, (2, 3), "wrap"), [4, 5, 1, 2, 3, 4, 5, 1, 2, 3] * self.ureg.m
)

def pad_with(vector, pad_width, iaxis, kwargs):
pad_value = kwargs.get("padder", 10)
vector[: pad_width[0]] = pad_value
vector[-pad_width[1] :] = pad_value

b = self.Q_(np.arange(6).reshape((2, 3)), "degC")
self.assertQuantityEqual(
np.pad(b, 2, pad_with),
self.Q_(
[
[10, 10, 10, 10, 10, 10, 10],
[10, 10, 10, 10, 10, 10, 10],
[10, 10, 0, 1, 2, 10, 10],
[10, 10, 3, 4, 5, 10, 10],
[10, 10, 10, 10, 10, 10, 10],
[10, 10, 10, 10, 10, 10, 10],
],
"degC",
),
)
self.assertQuantityEqual(
np.pad(b, 2, pad_with, padder=100),
self.Q_(
[
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 0, 1, 2, 100, 100],
[100, 100, 3, 4, 5, 100, 100],
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 100, 100, 100, 100, 100],
],
"degC",
),
) # Note: Does not support Quantity pad_with vectorized callable use


@unittest.skip
class TestBitTwiddlingUfuncs(TestUFuncs):
Expand Down