diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 387b3b2a51a7..1d1b3512fd3e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4388,9 +4388,85 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: return concatenate(arrs, axis=1) -@util.implements(np.choose, skip_params=['out']) -def choose(a: ArrayLike, choices: Sequence[ArrayLike], +def choose(a: ArrayLike, choices: Array | np.ndarray | Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: + """Construct an array by stacking slices of choice arrays. + + JAX implementation of :func:`numpy.choose`. + + The semantics of this function can be confusing, but in the simplest case where + ``a`` is a one-dimensional array, ``choices`` is a two-dimensional array, and + all entries of ``a`` are in-bounds (i.e. ``0 <= a_i < len(choices)``), then the + function is equivalent to the following:: + + def choose(a, choices): + return jnp.array([choices[a_i, i] for i, a_i in enumerate(a)]) + + In the more general case, ``a`` may have any number of dimensions and ``choices`` + may be an arbitrary sequence of broadcast-compatible arrays. In this case, again + for in-bound indices, the logic is equivalent to:: + + def choose(a, choices): + a, *choices = jnp.broadcast_arrays(a, *choices) + choices = jnp.array(choices) + return jnp.array([choices[a[idx], *idx] for idx in np.ndindex(a.shape)]) + + The only additional complexity comes from the ``mode`` argument, which controls + the behavior for out-of-bound indices in ``a`` as described below. + + Args: + a: an N-dimensional array of integer indices. + choices: an array or sequence of arrays. All arrays in the sequence must be + mutually broadcast compatible with ``a``. + out: unused by JAX + mode: specify the out-of-bounds indexing mode; one of ``'raise'`` (default), + ``'wrap'``, or ``'clip'``. Note that the default mode of ``'raise'`` is + not compatible with JAX transformations. + + Returns: + an array containing stacked slices from ``choices`` at the indices + specified by ``a``. The shape of the result is + ``broadcast_shapes(a.shape, *(c.shape for c in choices))``. + + See also: + - :func:`jax.lax.switch`: choose between N functions based on an index. + + Examples: + Here is the simplest case of a 1D index array with a 2D choice array, + in which case this chooses the indexed value from each column: + + >>> choices = jnp.array([[ 1, 2, 3, 4], + ... [ 5, 6, 7, 8], + ... [ 9, 10, 11, 12]]) + >>> a = jnp.array([2, 0, 1, 0]) + >>> jnp.choose(a, choices) + Array([9, 2, 7, 4], dtype=int32) + + The ``mode`` argument specifies what to do with out-of-bound indices; + options are to either ``wrap`` or ``clip``: + + >>> a2 = jnp.array([2, 0, 1, 4]) # last index out-of-bound + >>> jnp.choose(a2, choices, mode='clip') + Array([ 9, 2, 7, 12], dtype=int32) + >>> jnp.choose(a2, choices, mode='wrap') + Array([9, 2, 7, 8], dtype=int32) + + In the more general case, ``choices`` may be a sequence of array-like + objects with any broadcast-compatible shapes. + + >>> choice_1 = jnp.array([1, 2, 3, 4]) + >>> choice_2 = 99 + >>> choice_3 = jnp.array([[10], + ... [20], + ... [30]]) + >>> a = jnp.array([[0, 1, 2, 0], + ... [1, 2, 0, 1], + ... [2, 0, 1, 2]]) + >>> jnp.choose(a, [choice_1, choice_2, choice_3], mode='wrap') + Array([[ 1, 99, 10, 4], + [99, 20, 3, 99], + [30, 2, 99, 30]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") util.check_arraylike('choose', a, *choices) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index c23f659bd3f9..fd2524bc350d 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -218,7 +218,7 @@ def cbrt(x: ArrayLike, /) -> Array: ... cdouble: Any def ceil(x: ArrayLike, /) -> Array: ... character = _np.character -def choose(a: ArrayLike, choices: Sequence[ArrayLike], +def choose(a: ArrayLike, choices: Array | _np.ndarray | Sequence[ArrayLike], out: None = ..., mode: str = ...) -> Array: ... def clip( x: ArrayLike | None = ...,