Skip to content
This repository has been archived by the owner on Jun 10, 2020. It is now read-only.

Commit

Permalink
ENH: add basic typing for np.ufunc (#44)
Browse files Browse the repository at this point in the history
* ENH: add basic typing for `np.ufunc`

This adds basic type hints for `np.ufunc` and types all ufuncs in the
top-level namespace as `np.ufunc`. Ufuncs are highly dynamic, e.g.

- Their call signatures can vary
- The `reduce`/`accumulate`/`reduceat`/`outer`/`at` methods may or may
  not always raise

so it is difficult to have precise types. A path forward would be
writing a mypy plugin to get more precise typing.

* MAINT: fix return type for `ufunc.__call__`

Co-authored-by: Stephan Hoyer <shoyer@gmail.com>
  • Loading branch information
person142 and shoyer authored Apr 16, 2020
1 parent 2162156 commit 8bf2d45
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ __pycache__
numpy_stubs.egg-info/
venv
.idea
*~
**~
156 changes: 156 additions & 0 deletions numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from numpy.core._internal import _ctypes
from typing import (
Any,
ByteString,
Callable,
Container,
Callable,
Dict,
Expand Down Expand Up @@ -618,5 +619,160 @@ WRAP: int
little_endian: int
tracemalloc_domain: int

class ufunc:
def __call__(
self,
*args: _ArrayLike,
out: Optional[Union[ndarray, Tuple[ndarray, ...]]] = ...,
where: Optional[ndarray] = ...,
# The list should be a list of tuples of ints, but since we
# don't know the signature it would need to be
# Tuple[int, ...]. But, since List is invariant something like
# e.g. List[Tuple[int, int]] isn't a subtype of
# List[Tuple[int, ...]], so we can't type precisely here.
axes: List[Any] = ...,
axis: int = ...,
keepdims: bool = ...,
# TODO: make this precise when we can use Literal.
casting: str = ...,
# TODO: make this precise when we can use Literal.
order: Optional[str] = ...,
dtype: Optional[_DtypeLike] = ...,
subok: bool = ...,
signature: Union[str, Tuple[str]] = ...,
# In reality this should be a length of list 3 containing an
# int, an int, and a callable, but there's no way to express
# that.
extobj: List[Union[int, Callable]] = ...,
) -> Union[ndarray, generic]: ...
@property
def nin(self) -> int: ...
@property
def nout(self) -> int: ...
@property
def nargs(self) -> int: ...
@property
def ntypes(self) -> int: ...
@property
def types(self) -> List[str]: ...
# Broad return type because it has to encompass things like
#
# >>> np.logical_and.identity is True
# True
# >>> np.add.identity is 0
# True
# >>> np.sin.identity is None
# True
#
# and any user-defined ufuncs.
@property
def identity(self) -> Any: ...
# This is None for ufuncs and a string for gufuncs.
@property
def signature(self) -> Optional[str]: ...
# The next four methods will always exist, but they will just
# raise a ValueError ufuncs with that don't accept two input
# arguments and return one output argument. Because of that we
# can't type them very precisely.
@property
def reduce(self) -> Any: ...
@property
def accumulate(self) -> Any: ...
@property
def reduceat(self) -> Any: ...
@property
def outer(self) -> Any: ...
# Similarly at won't be defined for ufuncs that return multiple
# outputs, so we can't type it very precisely.
@property
def at(self) -> Any: ...

absolute: ufunc
add: ufunc
arccos: ufunc
arccosh: ufunc
arcsin: ufunc
arcsinh: ufunc
arctan2: ufunc
arctan: ufunc
arctanh: ufunc
bitwise_and: ufunc
bitwise_or: ufunc
bitwise_xor: ufunc
cbrt: ufunc
ceil: ufunc
conjugate: ufunc
copysign: ufunc
cos: ufunc
cosh: ufunc
deg2rad: ufunc
degrees: ufunc
divmod: ufunc
equal: ufunc
exp2: ufunc
exp: ufunc
expm1: ufunc
fabs: ufunc
float_power: ufunc
floor: ufunc
floor_divide: ufunc
fmax: ufunc
fmin: ufunc
fmod: ufunc
frexp: ufunc
gcd: ufunc
greater: ufunc
greater_equal: ufunc
heaviside: ufunc
hypot: ufunc
invert: ufunc
isfinite: ufunc
isinf: ufunc
isnan: ufunc
isnat: ufunc
lcm: ufunc
ldexp: ufunc
left_shift: ufunc
less: ufunc
less_equal: ufunc
log10: ufunc
log1p: ufunc
log2: ufunc
log: ufunc
logaddexp2: ufunc
logaddexp: ufunc
logical_and: ufunc
logical_not: ufunc
logical_or: ufunc
logical_xor: ufunc
matmul: ufunc
maximum: ufunc
minimum: ufunc
modf: ufunc
multiply: ufunc
negative: ufunc
nextafter: ufunc
not_equal: ufunc
positive: ufunc
power: ufunc
rad2deg: ufunc
radians: ufunc
reciprocal: ufunc
remainder: ufunc
right_shift: ufunc
rint: ufunc
sign: ufunc
signbit: ufunc
sin: ufunc
sinh: ufunc
spacing: ufunc
sqrt: ufunc
square: ufunc
subtract: ufunc
tan: ufunc
tanh: ufunc
true_divide: ufunc
trunc: ufunc

# TODO(shoyer): remove when the full numpy namespace is defined
def __getattr__(name: str) -> Any: ...
21 changes: 21 additions & 0 deletions scripts/find_ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np


def main():
ufuncs = []
for obj_name in np.__dir__():
obj = getattr(np, obj_name)
if isinstance(obj, np.ufunc):
ufuncs.append(obj)

ufunc_stubs = []
for ufunc in set(ufuncs):
ufunc_stubs.append(f'{ufunc.__name__}: ufunc')
ufunc_stubs.sort()

for stub in ufunc_stubs:
print(stub)


if __name__ == '__main__':
main()
5 changes: 5 additions & 0 deletions tests/fail/ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import numpy as np

np.sin.nin + 'foo' # E: Unsupported operand types
np.sin(1, foo='bar') # E: Unexpected keyword argument
np.sin(1, extobj=['foo', 'foo', 'foo']) # E: incompatible type
14 changes: 14 additions & 0 deletions tests/pass/ufuncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy as np

np.sin(1)
np.sin([1, 2, 3])
np.sin(1, out=np.empty(1))
np.matmul(
np.ones((2, 2, 2)),
np.ones((2, 2, 2)),
axes=[(0, 1), (0, 1), (0, 1)],
)
np.sin(1, signature='D')
np.sin(1, extobj=[16, 1, lambda: None])
np.sin(1) + np.sin(1)
np.sin.types[0]

0 comments on commit 8bf2d45

Please sign in to comment.