Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DOC: Improve docs for jax.numpy: square, sqrt and modf #23834

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,36 @@ def tanh(x: ArrayLike, /) -> Array:
def arctanh(x: ArrayLike, /) -> Array:
return lax.atanh(*promote_args_inexact('arctanh', x))

@implements(np.sqrt, module='numpy')

@partial(jit, inline=True)
def sqrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise non-negative square root of the input array.

JAX implementation of :obj:`numpy.sqrt`.

Args:
x: input array or scalar.

Returns:
An array containing the non-negative square root of the elements of ``x``.

Note:
- For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output.
- For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output.

See also:
- :func:`jax.numpy.square`: Calculates the element-wise square of the input.
- :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential
of ``x2``.

Examples:
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.sqrt(x)
Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
"""
return lax.sqrt(*promote_args_inexact('sqrt', x))

@implements(np.cbrt, module='numpy')
Expand Down Expand Up @@ -2162,9 +2189,50 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.rem(*promote_args_numeric("fmod", x1, x2))


@implements(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
"""Calculate element-wise square of the input array.

JAX implementation of :obj:`numpy.square`.

Args:
x: input array or scalar.

Returns:
An array containing the square of the elements of ``x``.

Note:
``jnp.square`` is equivalent to computing ``jnp.power(x, 2)``.

See also:
- :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root
of the input array.
- :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential
of ``x2``.
- :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where
:math:`y` is a fixed integer.
- :func:`jax.numpy.float_power`: Computes the first array raised to the power
of second array, element-wise, by promoting to the inexact dtype.

Examples:
>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)

For integer inputs:

>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)

For complex-valued inputs:

>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)
"""
check_arraylike("square", x)
x, = promote_dtypes_numeric(x)
return lax.integer_pow(x, 2)
Expand Down Expand Up @@ -2343,9 +2411,32 @@ def real(val: ArrayLike, /) -> Array:
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)

@implements(np.modf, module='numpy', skip_params=['out'])

@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
"""Return element-wise fractional and integral parts of the input array.

JAX implementation of :obj:`numpy.modf`.

Args:
x: input array or scalar.
out: Not used by JAX.

Returns:
An array containing the fractional and integral parts of the elements of ``x``,
promoting dtypes inexact.

See also:
- :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of
``x1`` by ``x2`` element-wise.

Examples:
>>> jnp.modf(4.8)
(Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
>>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3])
>>> jnp.modf(x)
(Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32))
"""
check_arraylike("modf", x)
x, = promote_dtypes_inexact(x)
if out is not None:
Expand Down