Skip to content

Commit

Permalink
BUG: Fix the signature for np.array_api.take
Browse files Browse the repository at this point in the history
The array_api take() doesn't flatten the array by default, so the axis
argument must be provided for multidimensional arrays. However, it should be
optional when the input array is 1-D, which the signature previously did not
allow.

c.f. data-apis/array-api#644
  • Loading branch information
asmeurer committed Jul 15, 2023
1 parent 308b348 commit 37ba69c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions numpy/array_api/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

import numpy as np

def take(x: Array, indices: Array, /, *, axis: int) -> Array:
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
See its docstring for more information.
"""
"""
if axis is None and x.ndim != 1:
raise ValueError("axis must be specified when ndim > 1")
if indices.dtype not in _integer_dtypes:
raise TypeError("Only integer dtypes are allowed in indexing")
if indices.ndim != 1:
if indices.ndim != 1:
raise ValueError("Only 1-dim indices array is supported")
return Array._new(np.take(x._array, indices._array, axis=axis))

0 comments on commit 37ba69c

Please sign in to comment.