Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: improve docs of transpose & matrix_transpose #20957

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 112 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,78 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
view of the input.
"""

@util.implements(np.transpose, lax_description=_ARRAY_VIEW_DOC)
def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array:
"""Return a transposed version of an N-dimensional array.

JAX implementation of :func:`jax.numpy.transpose`, implemented in terms of
:func:`jax.lax.transpose`.

Args:
a: input array
axes: optionally specify the permutation using a length-`a.ndim` sequence of integers
``i`` satisfying ``0 <= i < a.ndim``. Defaults to ``range(a.ndim)[::-1]``, i.e
reverses the order of all axes.

Returns:
transposed copy of the array.

See Also:
- :func:`jax.Array.transpose`: equivalent function via an :class:`~jax.Array` method.
- :attr:`jax.Array.T`: equivalent function via an :class:`~jax.Array` property.
- :func:`jax.numpy.matrix_transpose`: transpose the last two axes of an array. This is
suitable for working with batched 2D matrices.
- :func:`jax.numpy.swapaxes`: swap any two axes in an array.
- :func:`jax.numpy.moveaxis`: move an axis to another postion in the array.

Note:
Unlike :func:`numpy.transpose`, :func:`jax.numpy.transpose` will return a copy rather
than a view of the input array. However, under JIT, the compiler will optimize-away
such copies when possible, so this doesn't have performance impacts in practice.

Examples:
For a 1D array, the transpose is the identity:

>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

For a 2D array, the transpose is a matrix transpose:

>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
[2, 4]], dtype=int32)

For an N-dimensional array, the transpose reverses the order of the axes:

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

The ``axes`` argument can be specified to change this default behavior:

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)

Since swapping the last two axes is a common operation, it can be done
via its own API, :func:`jax.numpy.matrix_transpose`:

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)

For convenience, transposes may also be performed using the :meth:`jax.Array.transpose`
method or the :attr:`jax.Array.T` property:

>>> x = jnp.array([[1, 2],
... [3, 4]])
>>> x.transpose()
Array([[1, 3],
[2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
[2, 4]], dtype=int32)
"""
util.check_arraylike("transpose", a)
axes_ = list(range(ndim(a))[::-1]) if axes is None else axes
axes_ = [_canonicalize_axis(i, ndim(a)) for i in axes_]
Expand All @@ -555,19 +625,50 @@ def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array:
return lax.transpose(a, axes)


@util.implements(getattr(np, 'matrix_transpose', None))
def matrix_transpose(x: ArrayLike, /) -> Array:
"""Transposes the last two dimensions of x.
"""Transpose the last two dimensions of an array.

JAX implementation of :func:`jax.numpy.matrix_transpose`, implemented in terms of
:func:`jax.lax.transpose`.

Parameters
----------
x : array_like
Input array. Must have ``x.ndim >= 2``.
Args:
x: input array, Must have ``x.ndim >= 2``

Returns:
matrix-transposed copy of the array.

Returns
-------
xT : Array
Transposed array.
See Also:
- :attr:`jax.Array.mT`: same operation accessed via an :func:`~jax.Array` property.
- :func:`jax.numpy.transpose`: general multi-axis transpose

Note:
Unlike :func:`numpy.matrix_transpose`, :func:`jax.numpy.matrix_transpose` will return a
copy rather than a view of the input array. However, under JIT, the compiler will
optimize-away such copies when possible, so this doesn't have performance impacts in practice.

Examples:
Here is a 2x2x2 matrix representing a batched 2x2 matrix:

>>> x = jnp.array([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
[2, 4]],
<BLANKLINE>
[[5, 7],
[6, 8]]], dtype=int32)

For convenience, you can perform the same transpose via the :attr:`~jax.Array.mT`
property of :class:`jax.Array`:

>>> x.mT
Array([[[1, 3],
[2, 4]],
<BLANKLINE>
[[5, 7],
[6, 8]]], dtype=int32)
"""
util.check_arraylike("matrix_transpose", x)
ndim = np.ndim(x)
Expand Down
14 changes: 11 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6018,9 +6018,17 @@ def test_lax_numpy_docstrings(self):

# Functions that have their own docstrings & don't wrap numpy.
known_exceptions = {
'fromfile', 'fromiter', 'frompyfunc', 'vectorize',
'argwhere', 'where', 'nonzero', 'flatnonzero'}

'argwhere',
'flatnonzero',
'fromfile',
'fromiter',
'frompyfunc',
'matrix_transpose',
'nonzero',
'transpose',
'vectorize',
'where',
}
for name in dir(jnp):
if name in known_exceptions or name.startswith('_'):
continue
Expand Down