Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.scipy.fft: manually document functions to avoid scipy import #20799

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions jax/_src/scipy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
from functools import partial
import math

import scipy.fft as osp_fft
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import implements, promote_dtypes_complex
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.typing import Array

def _W4(N: int, k: Array) -> Array:
Expand All @@ -42,9 +41,30 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)

@implements(osp_fft.dct)

def dct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the discrete cosine transform of the input

JAX implementation of :func:`scipy.fft.dct`.

Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.

Returns:
array containing the discrete cosine transform of x

See Also:
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand Down Expand Up @@ -81,11 +101,31 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
return out


@implements(osp_fft.dctn)
def dctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional discrete cosine transform of the input

JAX implementation of :func:`scipy.fft.dctn`.

Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.

Returns:
array containing the discrete cosine transform of x

See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand All @@ -109,9 +149,29 @@ def dctn(x: Array, type: int = 2,
return x


@implements(osp_fft.idct)
def idct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the inverse discrete cosine transform of the input

JAX implementation of :func:`scipy.fft.idct`.

Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.

Returns:
array containing the inverse discrete cosine transform of x

See Also:
- :func:`jax.scipy.fft.dct`: DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand All @@ -126,7 +186,6 @@ def idct(x: Array, type: int = 2, n: int | None = None,
x = _dct_ortho_norm(x, axis)
x = _dct_ortho_norm(x, axis)


k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
# everything is complex from here...
w4 = _W4(N,k)
Expand All @@ -139,11 +198,32 @@ def idct(x: Array, type: int = 2, n: int | None = None,
out = _dct_deinterleave(x.real, axis)
return out

@implements(osp_fft.idctn)

def idctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional inverse discrete cosine transform of the input

JAX implementation of :func:`scipy.fft.idctn`.

Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.

Returns:
array containing the inverse discrete cosine transform of x

See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand Down