Skip to content

Commit

Permalink
jax.scipy.fft: manually document functions to avoid scipy import
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 29, 2024
1 parent ba540ca commit 7592116
Showing 1 changed file with 90 additions and 10 deletions.
100 changes: 90 additions & 10 deletions jax/_src/scipy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
from functools import partial
import math

import scipy.fft as osp_fft
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import implements, promote_dtypes_complex
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.typing import Array

def _W4(N: int, k: Array) -> Array:
Expand All @@ -42,9 +41,30 @@ def _dct_ortho_norm(out: Array, axis: int) -> Array:
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)

@implements(osp_fft.dct)

def dct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand Down Expand Up @@ -81,11 +101,31 @@ def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
return out


@implements(osp_fft.dctn)
def dctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand All @@ -109,9 +149,29 @@ def dctn(x: Array, type: int = 2,
return x


@implements(osp_fft.idct)
def idct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand All @@ -126,7 +186,6 @@ def idct(x: Array, type: int = 2, n: int | None = None,
x = _dct_ortho_norm(x, axis)
x = _dct_ortho_norm(x, axis)


k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
# everything is complex from here...
w4 = _W4(N,k)
Expand All @@ -139,11 +198,32 @@ def idct(x: Array, type: int = 2, n: int | None = None,
out = _dct_deinterleave(x.real, axis)
return out

@implements(osp_fft.idctn)

def idctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode. Currently only ``"ortho"`` is supported.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')

Expand Down

0 comments on commit 7592116

Please sign in to comment.