Skip to content

Commit

Permalink
jnp.mask_indices: add docs & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 27, 2024
1 parent ff0a98a commit cf91e76
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
60 changes: 47 additions & 13 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6903,19 +6903,53 @@ def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int
return reductions.sum(a, axis=(-2, -1), dtype=dtype)


def _wrap_indices_function(f):
@util.implements(f, update_doc=False)
def wrapper(*args, **kwargs):
args = [core.concrete_or_error(
None, arg, f"argument {i} of jnp.{f.__name__}()")
for i, arg in enumerate(args)]
kwargs = {key: core.concrete_or_error(
None, val, f"argument '{key}' of jnp.{f.__name__}()")
for key, val in kwargs.items()}
return tuple(asarray(x) for x in f(*args, **kwargs))
return wrapper

mask_indices = _wrap_indices_function(np.mask_indices)
def mask_indices(n: int,
mask_func: Callable[[ArrayLike, int], Array],
k: int = 0, *, size: int | None = None) -> tuple[Array, Array]:
"""Return indices of a mask of an (n, n) array.
Args:
n: static integer array dimension.
mask_func: a function that takes a shape ``(n, n)`` array and
an optional offset ``k``, and returns a shape ``(n, n)`` mask.
Examples of functions with this signature are
:func:`~jax.numpy.triu` and :func:`~jax.numpy.tril`.
k: a scalar value passed to ``mask_func``.
size: optional argument specifying the static size of the output arrays.
This is passed to :func:`~jax.numpy.nonzero` when generating the indices
from the mask.
Returns:
a tuple of indices where ``mask_func`` is nonzero.
See also:
- :func:`jax.numpy.triu_indices`: compute ``mask_indices`` for :func:`~jax.numpy.triu`.
- :func:`jax.numpy.tril_indices`: compute ``mask_indices`` for :func:`~jax.numpy.tril`.
Examples:
Calling ``mask_indices`` on built-in masking functions:
>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
Calling ``mask_indices`` on a custom masking function:
>>> def mask_func(x, k=0):
... i = jnp.arange(x.shape[0])[:, None]
... j = jnp.arange(x.shape[1])
... return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
[ True, True, False],
[ True, False, True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
"""
i, j = nonzero(mask_func(ones((n, n)), k), size=size)
return (i, j)


def _triu_size(n, m, k):
Expand Down
11 changes: 11 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,17 @@ def testTrilIndicesFrom(self, shape, dtype, k):
args_maker = lambda: [rng(shape, dtype), k]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@jtu.sample_product(
n = [2, 3, 4],
k = [None, -1, 0, 1],
funcname = ['triu', 'tril']
)
def testMaskIndices(self, n, k, funcname):
kwds = {} if k is None else {'k': k}
jnp_result = jnp.mask_indices(n, getattr(jnp, funcname), **kwds)
np_result = np.mask_indices(n, getattr(np, funcname), **kwds)
self.assertArraysEqual(jnp_result, np_result, check_dtypes=False)

@jtu.sample_product(
dtype=default_dtypes,
a_shape=[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2), (1, 2), (0, 2), (2, 3), (2, 2, 2), (2, 2, 2, 2)],
Expand Down

0 comments on commit cf91e76

Please sign in to comment.