From d51ccdf628c51d0eeae805536990ed5da489990e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 1 May 2024 17:36:24 -0700 Subject: [PATCH] DOC: Improve docstrings for jax.scipy.linalg --- jax/_src/scipy/linalg.py | 1185 +++++++++++++++++++++++++++++++++++--- 1 file changed, 1102 insertions(+), 83 deletions(-) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index f5a3f75cae25..d064f28a5d81 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -17,7 +17,6 @@ from functools import partial import numpy as np -import scipy.linalg import textwrap from typing import overload, Any, Literal @@ -29,7 +28,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src.lax import qdwh from jax._src.numpy.util import ( - check_arraylike, implements, promote_dtypes, promote_dtypes_inexact, + check_arraylike, promote_dtypes, promote_dtypes_inexact, promote_dtypes_complex) from jax._src.typing import Array, ArrayLike @@ -46,17 +45,111 @@ def _cholesky(a: ArrayLike, lower: bool) -> Array: l = lax_linalg.cholesky(a if lower else jnp.conj(a.mT), symmetrize_input=False) return l if lower else jnp.conj(l.mT) -@implements(scipy.linalg.cholesky, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Compute the Cholesky decomposition of a matrix. + + JAX implementation of :func:`scipy.linalg.cholesky`. + + The Cholesky decomposition of a matrix `A` is: + + .. math:: + + A = U^HU = LL^H + + where `U` is an upper-triangular matrix and `L` is a lower-triangular matrix. + + Args: + a: input array, representing a (batched) positive-definite hermitian matrix. + Must have shape ``(..., N, N)``. + lower: if True, compute the lower Cholesky decomposition `L`. if False + (default), compute the upper Cholesky decomposition `U`. + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + array of shape ``(..., N, N)`` representing the cholesky decomposition + of the input. + + See Also: + - :func:`jax.numpy.linalg.cholesky`: NumPy-stype Cholesky API + - :func:`jax.lax.linalg.cholesky`: XLA-style Cholesky API + - :func:`jax.scipy.linalg.cho_factor` + - :func:`jax.scipy.linalg.cho_solve` + + Example: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Upper Cholesky factorization: + + >>> jax.scipy.linalg.cholesky(x) + Array([[1.4142135 , 0.70710677], + [0. , 1.2247449 ]], dtype=float32) + + Lower Cholesky factorization: + + >>> jax.scipy.linalg.cholesky(x, lower=True) + Array([[1.4142135 , 0. ], + [0.70710677, 1.2247449 ]], dtype=float32) + + Reconstructing ``x`` from its factorization: + + >>> L = jax.scipy.linalg.cholesky(x, lower=True) + >>> jnp.allclose(x, L @ L.T) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # Unused return _cholesky(a, lower) -@implements(scipy.linalg.cho_factor, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def cho_factor(a: ArrayLike, lower: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, bool]: + """Factorization for Cholesky-based linear solves + + JAX implementation of :func:`scipy.linalg.cho_factor`. This function returns + a result suitable for use with :func:`jax.scipy.linalg.cho_solve`. For direct + Cholesky decompositions, prefer :func:`jax.scipy.linalg.cholesky`. + + Args: + a: input array, representing a (batched) positive-definite hermitian matrix. + Must have shape ``(..., N, N)``. + lower: if True, compute the lower triangular Cholesky decomposition (default: False). + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + ``(c, lower)``: ``c`` is an array of shape ``(..., N, N)`` representing the lower or + upper cholesky decomposition of the input; ``lower`` is a boolean specifying whether + this is the lower or upper decomposition. + + See Also: + - :func:`jax.scipy.linalg.cholesky` + - :func:`jax.scipy.linalg.cho_solve` + + Example: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. + + >>> b = jnp.array([3., 4.]) + >>> cfac = jax.scipy.linalg.cho_factor(x) + >>> y = jax.scipy.linalg.cho_solve(cfac, b) + >>> y + Array([0.6666666, 1.6666666], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(x @ y, b) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # Unused return (cholesky(a, lower=lower), lower) @@ -70,10 +163,49 @@ def _cho_solve(c: ArrayLike, b: ArrayLike, lower: bool) -> Array: transpose_a=lower, conjugate_a=lower) return b -@implements(scipy.linalg.cho_solve, update_doc=False, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) + def cho_solve(c_and_lower: tuple[ArrayLike, bool], b: ArrayLike, overwrite_b: bool = False, check_finite: bool = True) -> Array: + """Solve a linear system using a Cholesky factorization + + JAX implementation of :func:`scipy.linalg.cho_solve`. Uses the output + of :func:`jax.scipy.linalg.cho_factor`. + + Args: + c_and_lower: ``(c, lower)``, where ``c`` is an array of shape ``(..., N, N)`` + representing the lower or upper cholesky decomposition of the matrix, and + ``lower`` is a boolean specifying whethe this is the lower or upper decomposition. + b: right-hand-side of linear system. Must have shape ``(..., N)`` + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + Array of shape ``(..., N)`` representing the solution of the linear system. + + See Also: + - :func:`jax.scipy.linalg.cholesky` + - :func:`jax.scipy.linalg.cho_factor` + + Example: + A small real Hermitian positive-definite matrix: + + >>> x = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the cholesky factorization via :func:`~jax.scipy.linalg.cho_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.cho_solve`. + + >>> b = jnp.array([3., 4.]) + >>> cfac = jax.scipy.linalg.cho_factor(x) + >>> y = jax.scipy.linalg.cho_solve(cfac, b) + >>> y + Array([0.6666666, 1.6666666], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(x @ y, b) + Array(True, dtype=bool) + """ del overwrite_b, check_finite # Unused c, lower = c_and_lower return _cho_solve(c, b, lower) @@ -112,17 +244,112 @@ def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: ... -@implements(scipy.linalg.svd, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lapack_driver')) + def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') -> Array | tuple[Array, Array, Array]: + r"""Compute the singular value decomposition. + + JAX implementation of :func:`scipy.linalg.svd`. + + The SVD of a matrix `A` is given by + + .. math:: + + A = U\Sigma V^H + + - :math:`U` contains the left singular vectors and satisfies :math:`U^HU=I` + - :math:`V` contains the right singular vectors and satisfies :math:`V^HV=I` + - :math:`\Sigma` is a diagonal matrix of singular values. + + Args: + a: input array, of shape ``(..., N, M)`` + full_matrices: if True (default) compute the full matrices; i.e. ``u`` and ``vh`` have + shape ``(..., N, N)`` and ``(..., M, M)``. If False, then the shapes are + ``(..., N, K)`` and ``(..., K, M)`` with ``K = min(N, M)``. + compute_uv: if True (default), return the full SVD ``(u, s, vh)``. If False then return + only the singular values ``s``. + overwrite_a: unused by JAX + check_finite: unused by JAX + lapack_driver: unused by JAX + + Returns: + A tuple of arrays ``(u, s, vh)`` if ``compute_uv`` is True, otherwise the array ``s``. + + - ``u``: left singular vectors of shape ``(..., N, N)`` if ``full_matrices`` is True + or ``(..., N, K)`` otherwise. + - ``s``: singular values of shape ``(..., K)`` + - ``vh``: conjugate-transposed right singular vectors of shape ``(..., M, M)`` + if ``full_matrices`` is True or ``(..., K, M)`` otherwise. + + where ``K = min(N, M)``. + + See also: + - :func:`jax.numpy.linalg.svd`: NumPy-style SVD API + - :func:`jax.lax.linalg.svd`: XLA-style SVD API + + Example: + Consider the SVD of a small real-valued array: + + >>> x = jnp.array([[1., 2., 3.], + ... [6., 5., 4.]]) + >>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False) + >>> s # doctest: +SKIP + Array([9.361919 , 1.8315067], dtype=float32) + + The singular vectors are in the columns of ``u`` and ``v = vt.T``. These vectors are + orthonormal, which can be demonstrated by comparing the matrix product with the + identity matrix: + + >>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + >>> v = vt.T + >>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + + Given the SVD, ``x`` can be reconstructed via matrix multiplication: + + >>> x_reconstructed = u @ jnp.diag(s) @ vt + >>> jnp.allclose(x_reconstructed, x) + Array(True, dtype=bool) + """ del overwrite_a, check_finite, lapack_driver # unused return _svd(a, full_matrices=full_matrices, compute_uv=compute_uv) -@implements(scipy.linalg.det, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def det(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Compute the determinant of a matrix + + JAX implementation of :func:`scipy.linalg.det`. + + Args: + a: input array, of shape ``(..., N, N)`` + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns + Determinant of shape ``a.shape[:-2]`` + + See Also: + :func:`jax.numpy.linalg.det`: NumPy-style determinant API + + Examples: + Determinant of a small 2D array: + + >>> x = jnp.array([[1., 2.], + ... [3., 4.]]) + >>> jax.scipy.linalg.det(x) + Array(-2., dtype=float32) + + Batch-wise determinant of multiple 2D arrays: + + >>> x = jnp.array([[[1., 2.], + ... [3., 4.]], + ... [[8., 5.], + ... [7., 9.]]]) + >>> jax.scipy.linalg.det(x) + Array([-2., 37.], dtype=float32) + """ del overwrite_a, check_finite # unused return jnp.linalg.det(a) @@ -182,13 +409,70 @@ def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: ... -@implements(scipy.linalg.eigh, - lax_description=_no_overwrite_and_chkfinite_doc, - skip_params=('overwrite_a', 'overwrite_b', 'turbo', 'check_finite')) def eigh(a: ArrayLike, b: ArrayLike | None = None, lower: bool = True, eigvals_only: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, turbo: bool = True, eigvals: None = None, type: int = 1, check_finite: bool = True) -> Array | tuple[Array, Array]: + """Compute eigenvalues and eigenvectors for a Hermitian matrix + + JAX implementation of :func:`jax.scipy.linalg.eigh`. + + Args: + a: Hermitian input array of shape ``(..., N, N)`` + b: optional Hermitian input of shape ``(..., N, N)``. If specified, compute + the generalized eigenvalue problem. + lower: if True (default) access only the lower portion of the input matrix. + Otherwise access only the upper portion. + eigvals_only: If True, compute only the eigenvalues. If False (default) compute + both eigenvalues and eigenvectors. + type: if ``b`` is specified, ``type`` gives the type of generalized eigenvalue + problem to be computed. Denoting ``(λ, v)`` as an eigenvalue, eigenvector pair: + + - ``type = 1`` solves ``a @ v = λ * b @ v`` (default) + - ``type = 2`` solves ``a @ b @ v = λ * v`` + - ``type = 3`` solves ``b @ a @ v = λ * v`` + + eigvals: a ``(low, high)`` tuple specifying which eigenvalues to compute. + overwrite_a: unused by JAX. + overwrite_b: unused by JAX. + turbo: unused by JAX. + check_finite: unused by JAX. + + Returns: + A tuple of arrays ``(eigvals, eigvecs)`` if ``eigvals_only`` is False, otherwise + an array ``eigvals``. + + - ``eigvals``: array of shape ``(..., N)`` containing the eigenvalues. + - ``eigvecs``: array of shape ``(..., N, N)`` containing the eigenvectors. + + See also: + - :func:`jax.numpy.linalg.eigh`: NumPy-style eigh API. + - :func:`jax.lax.linalg.eigh`: XLA-style eigh API. + - :func:`jax.numpy.linalg.eig`: non-hermitian eigenvalue problem. + - :func:`jax.scipy.linalg.eigh_tridiagonal`: tri-diagonal eigenvalue problem. + + Examples: + Compute the standard eigenvalue decomposition of a simple 2x2 matrix: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + >>> eigvals, eigvecs = jax.scipy.linalg.eigh(a) + >>> eigvals + Array([1., 3.], dtype=float32) + >>> eigvecs + Array([[-0.70710677, 0.70710677], + [ 0.70710677, 0.70710677]], dtype=float32) + + Eigenvectors are orthonormal: + + >>> jnp.allclose(eigvecs.T @ eigvecs, jnp.eye(2), atol=1E-5) + Array(True, dtype=bool) + + Solution satisfies the eigenvalue problem: + + >>> jnp.allclose(a @ eigvecs, eigvecs @ jnp.diag(eigvals)) + Array(True, dtype=bool) + """ del overwrite_a, overwrite_b, turbo, check_finite # unused return _eigh(a, b, lower, eigvals_only, eigvals, type) @@ -198,35 +482,230 @@ def _schur(a: Array, output: str) -> tuple[Array, Array]: a = a.astype(dtypes.to_complex_dtype(a.dtype)) return lax_linalg.schur(a) -@implements(scipy.linalg.schur) def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]: + """Compute the Schur decomposition + + JAX implementation of :func:`scipy.linalg.schur`. + + The Schur form `T` of a matrix `A` satisfies: + + .. math:: + + A = Z T Z^H + + where `Z` is unitary, and `T` is upper-triangular for the complex-valued Schur + decomposition (i.e. ``output="complex"``) and is quasi-upper-triangular for the + real-valued Schur decomposition (i.e. ``output="real"``). In the quasi-triangular + case, the diagonal may include 2x2 blocks associated with complex-valued + eigenvalue pairs of `A`. + + Args: + a: input array of shape ``(..., N, N)`` + output: Specify whether to compute the ``"real"`` (default) or ``"complex"`` + Schur decomposition. + + Returns: + A tuple of arrays ``(T, Z)`` + + - ``T`` is a shape ``(..., N, N)`` array containing the upper-triangular + Schur form of the input. + - ``Z`` is a shape ``(..., N, N)`` array containing the unitary Schur + transformation matrix. + + See also: + - :func:`jax.scipy.linalg.rsf2csf`: conver real Schur form to complex Schur form. + - :func:`jax.lax.linalg.schur`: XLA-style API for Schur decomposition. + + Example: + A Schur decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [1., 4., 2.], + ... [3., 2., 1.]]) + >>> T, Z = jax.scipy.linalg.schur(a) + + The Schur form ``T`` is quasi-upper-triangular in general, but is truly + upper-triangular in this case because the input matrix is symmetric: + + >>> T # doctest: +SKIP + Array([[-2.0000005 , 0.5066295 , -0.43360388], + [ 0. , 1.5505103 , 0.74519426], + [ 0. , 0. , 6.449491 ]], dtype=float32) + + The transformation matrix ``Z`` is unitary: + + >>> jnp.allclose(Z.T @ Z, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + The input can be reconstructed from the outputs: + + >>> jnp.allclose(Z @ T @ Z.T, a) + Array(True, dtype=bool) + """ if output not in ('real', 'complex'): raise ValueError( f"Expected 'output' to be either 'real' or 'complex', got {output=}.") return _schur(a, output) -@implements(scipy.linalg.inv, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array: + """Return the inverse of a square matrix + + JAX implementation of :func:`scipy.linalg.inv`. + + Args: + a: array of shape ``(..., N, N)`` specifying square array(s) to be inverted. + overwrite_a: unused in JAX + check_finite: unused in JAX + + Returns: + Array of shape ``(..., N, N)`` containing the inverse of the input. + + Notes: + In most cases, explicitly computing the inverse of a matrix is ill-advised. For + example, to compute ``x = inv(A) @ b``, it is more performant and numerically + precise to use a direct solve, such as :func:`jax.scipy.linalg.solve`. + + See Also: + - :func:`jax.numpy.linalg.inv`: NumPy-style API for matrix inverse + - :func:`jax.scipy.linalg.solve`: direct linear solver + + Example: + Compute the inverse of a 3x3 matrix + + >>> a = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> a_inv = jax.scipy.linalg.inv(a) + >>> a_inv # doctest: +SKIP + Array([[ 0. , -0.25 , 0.5 ], + [-0.25 , 0.5 , -0.25000003], + [ 0.5 , -0.25 , 0. ]], dtype=float32) + + Check that multiplying with the inverse gives the identity: + + >>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Multiply the inverse by a vector ``b``, to find a solution to ``a @ x = b`` + + >>> b = jnp.array([1., 4., 2.]) + >>> a_inv @ b + Array([ 0. , 1.25, -0.5 ], dtype=float32) + + Note, however, that explicitly computing the inverse in such a case can lead + to poor performance and loss of precision as the size of the problem grows. + Instead, you should use a direct solver like :func:`jax.scipy.linalg.solve`: + + >>> jax.scipy.linalg.solve(a, b) + Array([ 0. , 1.25, -0.5 ], dtype=float32) + """ del overwrite_a, check_finite # unused return jnp.linalg.inv(a) -@implements(scipy.linalg.lu_factor, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) @partial(jit, static_argnames=('overwrite_a', 'check_finite')) def lu_factor(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: + """Factorization for LU-based linear solves + + JAX implementation of :func:`scipy.linalg.lu_factor`. + + This function returns a result suitable for use with :func:`jax.scipy.linalg.lu_solve`. + For direct LU decompositions, prefer :func:`jax.scipy.linalg.lu`. + + Args: + a: input array of shape ``(..., M, N)``. + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + A tuple ``(lu, piv)`` + + - ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its + lower triangle and ``U`` in its upper. + - ``piv`` is an array of shape ``(..., K)`` with ``K = min(M, N)``, + which encodes the pivots. + + See Also: + - :func:`jax.scipy.linalg.lu` + - :func:`jax.scipy.linalg.lu_solve` + + Example: + Solving a small linear system via LU factorization: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. + + >>> b = jnp.array([3., 4.]) + >>> lufac = jax.scipy.linalg.lu_factor(a) + >>> y = jax.scipy.linalg.lu_solve(lufac, b) + >>> y + Array([0.6666666, 1.6666667], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(a @ y, b) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # unused a, = promote_dtypes_inexact(jnp.asarray(a)) lu, pivots, _ = lax_linalg.lu(a) return lu, pivots -@implements(scipy.linalg.lu_solve, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'check_finite')) @partial(jit, static_argnames=('trans', 'overwrite_b', 'check_finite')) def lu_solve(lu_and_piv: tuple[Array, ArrayLike], b: ArrayLike, trans: int = 0, overwrite_b: bool = False, check_finite: bool = True) -> Array: + """Solve a linear system using an LU factorization + + JAX implementation of :func:`scipy.linalg.lu_solve`. Uses the output + of :func:`jax.scipy.linalg.lu_factor`. + + Args: + lu_and_piv: ``(lu, piv)``, output of :func:`~jax.scipy.linalg.lu_factor`. + ``lu`` is an array of shape ``(..., M, N)``, containing ``L`` in its lower + triangle and ``U`` in its upper. ``piv`` is an array of shape ``(..., K)``, + with ``K = min(M, N)``, which encodes the pivots. + b: right-hand-side of linear system. Must have shape ``(..., M)`` + trans: type of system to solve. Options are: + + - ``0``: :math:`A x = b` + - ``1``: :math:`A^Tx = b` + - ``2``: :math:`A^Hx = b` + + overwrite_b: unused by JAX + check_finite: unused by JAX + + Returns: + Array of shape ``(..., N)`` representing the solution of the linear system. + + See Also: + - :func:`jax.scipy.linalg.lu` + - :func:`jax.scipy.linalg.lu_factor` + + Example: + Solving a small linear system via LU factorization: + + >>> a = jnp.array([[2., 1.], + ... [1., 2.]]) + + Compute the lu factorization via :func:`~jax.scipy.linalg.lu_factor`, + and use it to solve a linear equation via :func:`~jax.scipy.linalg.lu_solve`. + + >>> b = jnp.array([3., 4.]) + >>> lufac = jax.scipy.linalg.lu_factor(a) + >>> y = jax.scipy.linalg.lu_solve(lufac, b) + >>> y + Array([0.6666666, 1.6666667], dtype=float32) + + Check that the result is consistent: + + >>> jnp.allclose(a @ y, b) + Array(True, dtype=bool) + """ del overwrite_b, check_finite # unused lu, pivots = lu_and_piv m, _ = lu.shape[-2:] @@ -269,11 +748,75 @@ def lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: ... -@implements(scipy.linalg.lu, update_doc=False, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite')) + @partial(jit, static_argnames=('permute_l', 'overwrite_a', 'check_finite')) def lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array] | tuple[Array, Array, Array]: + """Compute the LU decomposition + + JAX implementation of :func:`scipy.linalg.lu`. + + The LU decomposition of a matrix `A` is: + + .. math:: + + A = P L U + + where `P` is a permutation matrix, `L` is lower-triangular and `U` is upper-triangular. + + Args: + a: array of shape ``(..., M, N)`` to decompose. + permute_l: if True, then permute ``L`` and return ``(P @ L, U)`` (default: False) + overwrite_a: not used by JAX + check_finite: not used by JAX + + Returns: + A tuple of arrays ``(P @ L, U)`` if ``permute_l`` is True, else ``(P, L, U)``: + + - ``P`` is a permutation matrix of shape ``(..., M, M)`` + - ``L`` is a lower-triangular matrix of shape ``(... M, K)`` + - ``U`` is an upper-triangular matrix of shape ``(..., K, N)`` + + with ``K = min(M, N)`` + + See also: + - :func:`jax.numpy.linalg.lu`: NumPy-style API for LU decomposition. + - :func:`jax.lax.linalg.lu`: XLA-style API for LU decomposition. + - :func:`jax.scipy.linalg.lu_solve`: LU-based linear solver. + + Example: + An LU decomposition of a 3x3 matrix: + + >>> a = jnp.array([[1., 2., 3.], + ... [5., 4., 2.], + ... [3., 2., 1.]]) + >>> P, L, U = jax.scipy.linalg.lu(a) + + ``P`` is a permutation matrix: i.e. each row and column has a single ``1``: + + >>> P + Array([[0., 1., 0.], + [1., 0., 0.], + [0., 0., 1.]], dtype=float32) + + ``L`` and ``U`` are lower-triangular and upper-triangular matrices: + + >>> with jnp.printoptions(precision=3): + ... print(L) + ... print(U) + [[ 1. 0. 0. ] + [ 0.2 1. 0. ] + [ 0.6 -0.333 1. ]] + [[5. 4. 2. ] + [0. 1.2 2.6 ] + [0. 0. 0.667]] + + The original matrix can be reconstructed by multiplying the three together: + + >>> a_reconstructed = P @ L @ U + >>> jnp.allclose(a, a_reconstructed) + Array(True, dtype=bool) + """ del overwrite_a, check_finite # unused return _lu(a, permute_l) @@ -320,10 +863,77 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, *, mode: Lit def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: ... -@implements(scipy.linalg.qr, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite', 'lwork')) + def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "full", pivoting: bool = False, check_finite: bool = True) -> tuple[Array] | tuple[Array, Array]: + """Compute the QR decomposition of an array + + JAX implementation of :func:`scipy.linalg.qr`. + + The QR decomposition of a matrix `A` is given by + + .. math:: + + A = QR + + Where `Q` is a unitary matrix (i.e. :math:`Q^HQ=I`) and `R` is an upper-triangular + matrix. + + Args: + a: array of shape (..., M, N) + mode: Computational mode. Supported values are: + + - ``"full"`` (default): return `Q` of shape ``(M, M)`` and `R` of shape ``(M, N)``. + - ``"r"``: return only `R` + - ``"economic"``: return `Q` of shape ``(M, K)`` and `R` of shape ``(K, N)``, + where K = min(M, N). + + pivoting: Not implemened in JAX. + overwrite_a: unused in JAX + lwork: unused in JAX + check_finite: unused in JAX + + Returns: + A tuple ``(Q, R)`` (if ``mode`` is not ``"r"``) otherwise an array ``R``, + where: + + - ``Q`` is an orthogonal matrix of shape ``(..., M, M)`` (if ``mode`` is ``"full"``) + or ``(..., M, K)`` (if ``mode`` is ``"economic"``). + - ``R`` is an upper-triangular matrix of shape ``(..., M, N)`` (if ``mode`` is + ``"r"`` or ``"full"``) or ``(..., K, N)`` (if ``mode`` is ``"economic"``) + + with ``K = min(M, N)``. + + See also: + - :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API + - :func:`jax.lax.linalg.qr`: XLA-style QR decompositon API + + Examples: + Compute the QR decomposition of a matrix: + + >>> a = jnp.array([[1., 2., 3., 4.], + ... [5., 4., 2., 1.], + ... [6., 3., 1., 5.]]) + >>> Q, R = jax.scipy.linalg.qr(a) + >>> Q # doctest: +SKIP + Array([[-0.12700021, -0.7581426 , -0.6396022 ], + [-0.63500065, -0.43322435, 0.63960224], + [-0.7620008 , 0.48737738, -0.42640156]], dtype=float32) + >>> R # doctest: +SKIP + Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ], + [ 0. , -1.7870499, -2.6534991, -1.028908 ], + [ 0. , 0. , -1.0660033, -4.050814 ]], dtype=float32) + + Check that ``Q`` is orthonormal: + + >>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5) + Array(True, dtype=bool) + + Reconstruct the input: + + >>> jnp.allclose(Q @ R, a) + Array(True, dtype=bool) + """ del overwrite_a, lwork, check_finite # unused return _qr(a, mode, pivoting) @@ -352,12 +962,59 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: return vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) -@implements(scipy.linalg.solve, - lax_description=_no_overwrite_and_chkfinite_doc, - skip_params=('overwrite_a', 'overwrite_b', 'debug', 'check_finite')) def solve(a: ArrayLike, b: ArrayLike, lower: bool = False, overwrite_a: bool = False, overwrite_b: bool = False, debug: bool = False, check_finite: bool = True, assume_a: str = 'gen') -> Array: + """Solve a linear system of equations + + JAX implementation of :func:`scipy.linalg.solve`. + + This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` + given ``a`` and ``b``. + + Args: + a: array of shape ``(..., N, N)``. + b: array of shape ``(..., N)`` or ``(..., N, M)`` + lower: Referenced only if ``assume_a != 'gen'``. If True, only use the lower + triangle of the input, If False (default), only use the upper triangle. + assume_a: specify what properties of ``a`` can be assumed. Options are: + + - ``"gen"``: generic matrix (default) + - ``"sym"``: symmetric matrix + - ``"her"``: hermitian matrix + - ``"pos"``: positive-definite matrix + + overwrite_a: unused by JAX + overwrite_b: unused by JAX + debug: unused by JAX + check_finite: unused by JAX + + Returns: + An array of the same shape as ``b`` containing the solution to the linear system. + + See also: + - :func:`jax.scipy.linalg.lu_solve`: Solve via LU factorization. + - :func:`jax.scipy.linalg.cho_solve`: Solve via Cholesky factorization. + - :func:`jax.scipy.linalg.solve_triangular`: Solve a triangular system. + - :func:`jax.numpy.linalg.solve`: NumPy-style API for solving linear systems. + - :func:`jax.lax.custom_linear_solve`: matrix-free linear solver. + + Example: + A simple 3x3 linear system: + + >>> A = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> b = jnp.array([14., 16., 10.]) + >>> x = jax.scipy.linalg.solve(A, b) + >>> x + Array([1., 2., 3.], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A @ x, b) + Array(True, dtype=bool) + """ del overwrite_a, overwrite_b, debug, check_finite #unused valid_assume_a = ['gen', 'sym', 'her', 'pos'] if assume_a not in valid_assume_a: @@ -391,32 +1048,96 @@ def _solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str, else: return out -@implements(scipy.linalg.solve_triangular, - lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_b', 'debug', 'check_finite')) + def solve_triangular(a: ArrayLike, b: ArrayLike, trans: int | str = 0, lower: bool = False, unit_diagonal: bool = False, overwrite_b: bool = False, debug: Any = None, check_finite: bool = True) -> Array: - del overwrite_b, debug, check_finite # unused - return _solve_triangular(a, b, trans, lower, unit_diagonal) + """Solve a triangular linear system of equations -_expm_description = textwrap.dedent(""" -In addition to the original NumPy argument(s) listed below, -also supports the optional boolean argument ``upper_triangular`` -to specify whether the ``A`` matrix is upper triangular, and the optional -argument ``max_squarings`` to specify the max number of squarings allowed -in the scaling-and-squaring approximation method. Return nan if the actual -number of squarings required is more than ``max_squarings``. + JAX implementation of :func:`scipy.linalg.solve_triangular`. -The number of required squarings = max(0, ceil(log2(norm(A)) - c) -where norm() denotes the L1 norm, and + This solves a (batched) linear system of equations ``a @ x = b`` for ``x`` + given a triangular matrix ``a`` and a vector or matrix ``b``. + + Args: + a: array of shape ``(..., N, N)``. Only part of the array will be accessed, + depending on the ``lower`` and ``unit_diagonal`` arguments. + b: array of shape ``(..., N)`` or ``(..., N, M)`` + lower: If True, only use the lower triangle of the input, If False (default), + only use the upper triangle. + unit_diagonal: If True, ignore diagonal elements of ``a`` and assume they are + ``1`` (default: False). + trans: specify what properties of ``a`` can be assumed. Options are: + + - ``0`` or ``'N'``: solve :math:`Ax=b` + - ``1`` or ``'T'``: solve :math:`A^Tx=b` + - ``2`` or ``'C'``: solve :math:`A^Hx=b` + + overwrite_b: unused by JAX + debug: unused by JAX + check_finite: unused by JAX + + Returns: + An array of the same shape as ``b`` containing the solution to the linear system. + + See also: + :func:`jax.scipy.linalg.solve`: Solve a general linear system. + + Example: + A simple 3x3 triangular linear system: + + >>> A = jnp.array([[1., 2., 3.], + ... [0., 3., 2.], + ... [0., 0., 5.]]) + >>> b = jnp.array([10., 8., 5.]) + >>> x = jax.scipy.linalg.solve_triangular(A, b) + >>> x + Array([3., 2., 1.], dtype=float32) + + Confirming that the result solves the system: + + >>> jnp.allclose(A @ x, b) + Array(True, dtype=bool) + + Computing the transposed problem: + + >>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T') + >>> x + Array([10. , -4. , -3.4], dtype=float32) + + Confiriming that the result solves the system: + + >>> jnp.allclose(A.T @ x, b) + Array(True, dtype=bool) + """ + del overwrite_b, debug, check_finite # unused + return _solve_triangular(a, b, trans, lower, unit_diagonal) -- c=2.42 for float64 or complex128, -- c=1.97 for float32 or complex64 -""") -@implements(scipy.linalg.expm, lax_description=_expm_description) @partial(jit, static_argnames=('upper_triangular', 'max_squarings')) def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array: + """Compute the matrix exponential + + JAX implementation of :func:`scipy.linalg.expm`. + + Args: + A: array of shape ``(..., N, N)`` + upper_triangular: if True, then assume that ``A`` is upper-triangular. Default=False. + max_squarings: The number of squarings in the scaling-and-squaring approximation method + (default: 16). + + Returns: + An array of shape ``(..., N, N)`` containing the matrix exponent of ``A``. + + Notes: + This uses the scaling-and-squaring approximation method, with computational complexity + controlled by the optional ``max_squarings`` argument. Theoretically, the number of + required squarings is ``max(0, ceil(log2(norm(A))) - c)`` where ``norm(A)`` is the L1 + norm and ``c=2.42`` for float64/complex128, or ``c=1.97`` for float32/complex64. + + See Also: + :func:`jax.scipy.linalg.expm_frechet` + """ A, = promote_dtypes_inexact(A) if A.ndim < 2 or A.shape[-1] != A.shape[-2]: @@ -554,12 +1275,6 @@ def _pade13(A: Array) -> tuple[Array, Array]: return U,V -_expm_frechet_description = textwrap.dedent(""" -Does not currently support the Scipy argument ``jax.numpy.asarray_chkfinite``, -because `jax.numpy.asarray_chkfinite` does not exist at the moment. Does not -support the ``method='blockEnlarge'`` argument. -""") - @overload def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) -> tuple[Array, Array]: ... @@ -572,34 +1287,89 @@ def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: ... -@implements(scipy.linalg.expm_frechet, lax_description=_expm_frechet_description) + @partial(jit, static_argnames=('method', 'compute_expm')) def expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) -> Array | tuple[Array, Array]: + """Compute the Frechet derivative of the matrix exponential. + + JAX implementation of :func:`scipy.linalg.expm_frechet` + + Args: + A: array of shape ``(..., N, N)`` + E: array of shape ``(..., N, N)``; specifies the direction of the derivative. + compute_expm: if True (default) then compute and return ``expm(A)``. + method: ignored by JAX + + Returns: + A tuple ``(expm_A, expm_frechet_AE)`` if ``compute_expm`` is True, else + the array ``expm_frechet_AE``. Both returned arrays have shape ``(..., N, N)``. + + See also: + :func:`jax.scipy.linalg.expm` + + Examples: + We can use this API to compute the matrix exponential of ``A``, as well as its + derivative in the direction ``E``: + + >>> key1, key2 = jax.random.split(jax.random.key(3372)) + >>> A = jax.random.normal(key1, (3, 3)) + >>> E = jax.random.normal(key2, (3, 3)) + >>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E) + + This can be equivalently computed using JAX's automatic differentiation methods; + here we'll compute the derivative of :func:`~jax.scipy.linalg.expm` in the + direction of ``E`` using :func:`jax.jvp`, and find the same results: + + >>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,)) + >>> jnp.allclose(expmA, expmA2) + Array(True, dtype=bool) + >>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2) + Array(True, dtype=bool) + """ + del method # unused A_arr = jnp.asarray(A) E_arr = jnp.asarray(E) - if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]: - raise ValueError('expected A to be a square matrix') - if E_arr.ndim != 2 or E_arr.shape[0] != E_arr.shape[1]: - raise ValueError('expected E to be a square matrix') + if A_arr.ndim < 2 or A_arr.shape[-2] != A_arr.shape[1]: + raise ValueError(f'expected A to be a (batched) square matrix, got A.shape={A_arr.shape}') + if E_arr.ndim < 2 or E_arr.shape[-2] != E_arr.shape[-1]: + raise ValueError(f'expected E to be a (batched) square matrix, got E.shape={E_arr.shape}') if A_arr.shape != E_arr.shape: - raise ValueError('expected A and E to be the same shape') - if method is None: - method = 'SPS' - if method == 'SPS': - bound_fun = partial(expm, upper_triangular=False, max_squarings=16) - expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,)) - else: - raise ValueError('only method=\'SPS\' is supported') + raise ValueError('expected A and E to be the same shape, got ' + f'A.shape={A_arr.shape} E.shape={E_arr.shape}') + bound_fun = partial(expm, upper_triangular=False, max_squarings=16) + expm_A, expm_frechet_AE = jvp(bound_fun, (A_arr,), (E_arr,)) if compute_expm: return expm_A, expm_frechet_AE else: return expm_frechet_AE -@implements(scipy.linalg.block_diag) @jit def block_diag(*arrs: ArrayLike) -> Array: + """Create a block diagonal matrix from input arrays. + + JAX implementation of :func:`scipy.linalg.block_diag`. + + Args: + *arrs: arrays of at most two dimensions + + Returns: + 2D block-diagonal array constructed by placing the input arrays + along the diagonal. + + Example: + >>> A = jnp.ones((1, 1)) + >>> B = jnp.ones((2, 2)) + >>> C = jnp.ones((3, 3)) + >>> jax.scipy.linalg.block_diag(A, B, C) + Array([[1., 0., 0., 0., 0., 0.], + [0., 1., 1., 0., 0., 0.], + [0., 1., 1., 0., 0., 0.], + [0., 0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1., 1.]], dtype=float32) + """ if len(arrs) == 0: arrs = (jnp.zeros((1, 0)),) arrs = tuple(promote_dtypes(*arrs)) @@ -619,11 +1389,54 @@ def block_diag(*arrs: ArrayLike) -> Array: return acc -@implements(scipy.linalg.eigh_tridiagonal) @partial(jit, static_argnames=("eigvals_only", "select", "select_range")) def eigh_tridiagonal(d: ArrayLike, e: ArrayLike, *, eigvals_only: bool = False, select: str = 'a', select_range: tuple[float, float] | None = None, tol: float | None = None) -> Array: + """Solve the eigenvalue problem for a symmetric real tridiagonal matrix + + JAX implementation of :func:`scipy.linalg.eigh_tridiagonal`. + + Args: + d: real-valued array of shape ``(N,)`` specifying the diagonal elements. + e: real-valued array of shape ``(N - 1,)`` specifying the off-diagonal elements. + eigvals_only: If True, return only the eigenvalues (default: False). Computation + of eigenvectors is not yet implemented, so ``eigvals_only`` must be set to True. + select: specify which eigenvalues to calculate. Supported values are: + + - ``'a'``: all eigenvalues + - ``'i'``: eigenvalues with indices ``select_range[0] <= i <= select_range[1]`` + + JAX does not currently implement ``select = 'v'``. + select_range: range of values used when ``select='i'``. + tol: absolute tolerance to use when solving for the eigenvalues. + + Returns: + An array of eigenvalues with shape ``(N,)``. + + See also: + :func:`jax.scipy.linalg.eigh`: general Hermitian eigenvalue solver + + Examples: + >>> d = jnp.array([1., 2., 3., 4.]) + >>> e = jnp.array([1., 1., 1.]) + >>> eigvals = jax.scipy.linalg.eigh_tridiagonal(d, e, eigvals_only=True) + >>> eigvals + Array([0.2547188, 1.8227171, 3.1772828, 4.745281 ], dtype=float32) + + For comparison, we can construct the full matrix and compute the same result + using :func:`~jax.scipy.linalg.eigh`: + + >>> A = jnp.diag(d) + jnp.diag(e, 1) + jnp.diag(e, -1) + >>> A + Array([[1., 1., 0., 0.], + [1., 2., 1., 0.], + [0., 1., 3., 1.], + [0., 0., 1., 4.]], dtype=float32) + >>> eigvals_full = jax.scipy.linalg.eigh(A, eigvals_only=True) + >>> jnp.allclose(eigvals, eigvals_full) + Array(True, dtype=bool) + """ if not eigvals_only: raise NotImplementedError("Calculation of eigenvectors is not implemented") @@ -901,26 +1714,103 @@ def _sqrtm(A: ArrayLike) -> Array: return jnp.matmul(jnp.matmul(Z, sqrt_T, precision=lax.Precision.HIGHEST), jnp.conj(Z.T), precision=lax.Precision.HIGHEST) -@implements(scipy.linalg.sqrtm, - lax_description=""" -This differs from ``scipy.linalg.sqrtm`` in that the return type of -``jax.scipy.linalg.sqrtm`` is always ``complex64`` for 32-bit input, -and ``complex128`` for 64-bit input. - -This function implements the complex Schur method described in [A]. It does not use recursive blocking -to speed up computations as a Sylvester Equation solver is not available yet in JAX. -[A] Björck, Å., & Hammarling, S. (1983). - "A Schur method for the square root of a matrix". Linear algebra and its applications, 52, 127-140. -""") def sqrtm(A: ArrayLike, blocksize: int = 1) -> Array: + """Compute the matrix square root + + JAX implementation of :func:`scipy.linalg.sqrtm`. + + Args: + A: array of shape ``(N, N)`` + blocksize: Not supported in JAX; JAX always uses ``blocksize=1``. + + Returns: + An array of shape ``(N, N)`` containing the matrix square root of ``A`` + + See Also: + :func:`jax.scipy.linalg.expm` + + Example: + >>> a = jnp.array([[1., 2., 3.], + ... [2., 4., 2.], + ... [3., 2., 1.]]) + >>> sqrt_a = jax.scipy.linalg.sqrtm(a) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(sqrt_a) + [[0.92+0.71j 0.54+0.j 0.92-0.71j] + [0.54+0.j 1.85+0.j 0.54-0.j ] + [0.92-0.71j 0.54-0.j 0.92+0.71j]] + + By definition, matrix multiplication of the matrix square root with itself should + equal the input: + + >>> jnp.allclose(a, sqrt_a @ sqrt_a) + Array(True, dtype=bool) + + Notes: + This function implements the complex Schur method described in [1]_. It does not use + recursive blocking to speed up computations as a Sylvester Equation solver is not + yet available in JAX. + + References: + .. [1] Björck, Å., & Hammarling, S. (1983). "A Schur method for the square root of a matrix". + Linear algebra and its applications, 52, 127-140. + """ if blocksize > 1: raise NotImplementedError("Blocked version is not implemented yet.") return _sqrtm(A) -@implements(scipy.linalg.rsf2csf, lax_description=_no_chkfinite_doc) + @partial(jit, static_argnames=('check_finite',)) def rsf2csf(T: ArrayLike, Z: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]: + """Convert real Schur form to complex Schur form. + + JAX implementation of :func:`scipy.linalg.rsf2csf`. + + Args: + T: array of shape ``(..., N, N)`` containing the real Schur form of the input. + Z: array of shape ``(..., N, N)`` containing the corresponding Schur transformation + matrix. + check_finite: unused by JAX + + Returns: + A tuple of arrays ``(T, Z)`` of the same shape as the inputs, containing the + Complex Schur form and the associated Schur transformation matrix. + + See Also: + :func:`jax.scipy.linalg.schur`: Schur decomposition + + Example: + >>> A = jnp.array([[0., 3., 3.], + ... [0., 1., 2.], + ... [2., 0., 1.]]) + >>> Tr, Zr = jax.scipy.linalg.schur(A) + >>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr) + + Both the real and complex form can be used to reconstruct the input matrix + to float32 precision: + + >>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5) + Array(True, dtype=bool) + >>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5) + Array(True, dtype=bool) + + The real-valued Schur form is only quasi-upper-triangular, as we can see in this case: + + >>> with jax.numpy.printoptions(precision=2, suppress=True): + ... print(Tr) + [[ 3.76 -2.17 1.38] + [ 0. -0.88 -0.35] + [ 0. 2.37 -0.88]] + + By contrast, the complex form is truely upper-triangular: + + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(Tc) + [[ 3.76+0.j 1.29-0.78j 2.02-0.5j ] + [ 0. +0.j -0.88+0.91j -2.02+0.j ] + [ 0. +0.j 0. +0.j -0.88-0.91j]] + """ del check_finite # unused T_arr = jnp.asarray(T) @@ -987,11 +1877,57 @@ def hessenberg(a: ArrayLike, *, calc_q: Literal[False], overwrite_a: bool = Fals def hessenberg(a: ArrayLike, *, calc_q: Literal[True], overwrite_a: bool = False, check_finite: bool = True) -> tuple[Array, Array]: ... -@implements(scipy.linalg.hessenberg, lax_description=_no_overwrite_and_chkfinite_doc) + @partial(jit, static_argnames=('calc_q', 'check_finite', 'overwrite_a')) def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, check_finite: bool = True) -> Array | tuple[Array, Array]: - del overwrite_a, check_finite + """Compute the Hessenberg form of the matrix + + JAX implementation of :func:`scipy.linalg.hessenberg`. + + The Hessenberg form `H` of a matrix `A` satisfies: + + .. math:: + + A = Q H Q^H + + where `Q` is unitary and `H` is zero below the first subdiagonal. + + Args: + a : array of shape ``(..., N, N)`` + calc_q: if True, calculate the ``Q`` matrix (default: False) + overwrite_a: unused by JAX + check_finite: unused by JAX + + Returns: + A tuple of arrays ``(H, Q)`` if ``calc_q`` is True, else an array ``H`` + + - ``H`` has shape ``(..., N, N)`` and is the Hessenberg form of ``a`` + - ``Q`` has shape ``(..., N, N)`` and is the associated unitary matrix + + Example: + Computing the Hessenberg form of a 4x4 matrix + + >>> a = jnp.array([[1., 2., 3., 4.], + ... [1., 4., 2., 3.], + ... [3., 2., 1., 4.], + ... [2., 3., 2., 2.]]) + >>> H, Q = jax.scipy.linalg.hessenberg(a, calc_q=True) + >>> with jnp.printoptions(suppress=True, precision=3): + ... print(H) + [[ 1. -5.078 1.167 1.361] + [-3.742 5.786 -3.613 -1.825] + [ 0. -2.992 2.493 -0.577] + [ 0. 0. -0.043 -1.279]] + + Notice the zeros in the subdiagonal positions. The original matrix + can be reconstructed using the ``Q`` vectors: + + >>> a_reconstructed = Q @ H @ Q.conj().T + >>> jnp.allclose(a_reconstructed, a) + Array(True, dtype=bool) + """ + del overwrite_a, check_finite # unused n = jnp.shape(a)[-1] if n == 0: if calc_q: @@ -1010,8 +1946,64 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, else: return h -@implements(scipy.linalg.toeplitz) + def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: + r"""Construct a Toeplitz matrix + + JAX implementation of :func:`scipy.linalg.toeplitz`. + + A Toeplitz matrix has equal diagonals: :math:`A_{ij} = k_{i - j}` + for :math:`0 \le i < n` and :math:`0 \le j < n`. This function + specifies the diagonals via the first column ``c`` and the first row + ``r``, such that for row `i` and column `j`: + + .. math:: + + A_{ij} = \begin{cases} + c_{i - j} & i \ge j \\ + r_{j - i} & i < j + \end{cases} + + Notice this implies that :math:`r_0` is ignored. + + Args: + c: array specifying the first column. Will be flattened + if not 1-dimensional. + r: (optional) array specifying the first row. If not specified, defaults + to ``conj(c)``. Will be flattened if not 1-dimensional. + + Returns: + toeplitz matrix of shape ``(c.size, r.size)``. + + Examples: + Specifying ``c`` only: + + >>> c = jnp.array([1, 2, 3]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], dtype=int32) + + Specifying ``c`` and ``r``: + + >>> r = jnp.array([-1, -2, -3]) + >>> jax.scipy.linalg.toeplitz(c, r) # Note r[0] is ignored + Array([[ 1, -2, -3], + [ 2, 1, -2], + [ 3, 2, 1]], dtype=int32) + + If specifying only complex-valued ``c``, ``r`` defaults to ``c.conj()``, + resulting in a Hermitian matrix if ``c[0].imag == 0``: + + >>> c = jnp.array([1, 2+1j, 1+2j]) + >>> M = jax.scipy.linalg.toeplitz(c) + >>> M + Array([[1.+0.j, 2.-1.j, 1.-2.j], + [2.+1.j, 1.+0.j, 2.-1.j], + [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) + >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) + M is Hermitian: True + """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) @@ -1036,8 +2028,35 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) -@implements(scipy.linalg.hilbert) + @partial(jit, static_argnames=("n",)) def hilbert(n: int) -> Array: + r"""Create a Hilbert matrix of order n. + + JAX implementation of :func:`scipy.linalg.hilbert`. + + The Hilbert matrix is defined by: + + .. math:: + + H_{ij} = \frac{1}{i + j + 1} + + for :math:`1 \le i \le n` and :math:`1 \le j \le n`. + + Args: + n: the size of the matrix to create. + + Returns: + A Hilbert matrix of shape ``(n, n)`` + + Examples: + >>> jax.scipy.linalg.hilbert(2) + Array([[1. , 0.5 ], + [0.5 , 0.33333334]], dtype=float32) + >>> jax.scipy.linalg.hilbert(3) + Array([[1. , 0.5 , 0.33333334], + [0.5 , 0.33333334, 0.25 ], + [0.33333334, 0.25 , 0.2 ]], dtype=float32) + """ a = lax.broadcasted_iota(jnp.float64, (n, 1), 0) return 1/(a + a.T + 1)