Skip to content

Commit

Permalink
Improve docs for jax.numpy: square, sqrt and modf
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasekharporeddy committed Sep 23, 2024
1 parent c05706b commit e976dee
Showing 1 changed file with 94 additions and 3 deletions.
97 changes: 94 additions & 3 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,36 @@ def tanh(x: ArrayLike, /) -> Array:
def arctanh(x: ArrayLike, /) -> Array:
return lax.atanh(*promote_args_inexact('arctanh', x))

@implements(np.sqrt, module='numpy')

@partial(jit, inline=True)
def sqrt(x: ArrayLike, /) -> Array:
"""Calculates element-wise non-negative square root of the input array.
JAX implementation of :obj:`numpy.sqrt`.
Args:
x: input array or scalar.
Returns:
An array containing the non-negative square root of the elements of ``x``.
Note:
- For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output.
- For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output.
See also:
- :func:`jax.numpy.square`: Calculates the element-wise square of the input.
- :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential
of ``x2``.
Examples:
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
... jnp.sqrt(x)
Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
"""
return lax.sqrt(*promote_args_inexact('sqrt', x))

@implements(np.cbrt, module='numpy')
Expand Down Expand Up @@ -2162,9 +2189,50 @@ def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax.rem(*promote_args_numeric("fmod", x1, x2))


@implements(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike, /) -> Array:
"""Calculate element-wise square of the input array.
JAX implementation of :obj:`numpy.square`.
Args:
x: input array or scalar.
Returns:
An array containing the square of the elements of ``x``.
Note:
``jnp.square`` is equivalent to computing ``jnp.power(x, 2)``.
See also:
- :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root
of the input array.
- :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential
of ``x2``.
- :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where
:math:`y` is a fixed integer.
- :func:`jax.numpy.float_power`: Computes the first array raised to the power
of second array, element-wise, by promoting to the inexact dtype.
Examples:
>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
For integer inputs:
>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)
For complex-valued inputs:
>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)
"""
check_arraylike("square", x)
x, = promote_dtypes_numeric(x)
return lax.integer_pow(x, 2)
Expand Down Expand Up @@ -2343,9 +2411,32 @@ def real(val: ArrayLike, /) -> Array:
check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val)

@implements(np.modf, module='numpy', skip_params=['out'])

@jit
def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]:
"""Return element-wise fractional and integral parts of the input array.
JAX implementation of :obj:`numpy.modf`.
Args:
x: input array or scalar.
out: Not used by JAX.
Returns:
An array containing the fractional and integral parts of the elements of ``x``,
promoting dtypes inexact.
See also:
- :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of
``x1`` by ``x2`` element-wise.
Examples:
>>> jnp.modf(4.8)
(Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
>>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3])
>>> jnp.modf(x)
(Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32))
"""
check_arraylike("modf", x)
x, = promote_dtypes_inexact(x)
if out is not None:
Expand Down

0 comments on commit e976dee

Please sign in to comment.