Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring kernels #206

Merged
merged 8 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 17 additions & 29 deletions gpjax/kernels/stationary/matern12.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,45 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array, Float
import distrax as dx

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution
from ..computations import DenseKernelComputation
from .utils import build_student_t_distribution, euclidean_distance


@dataclass
class Matern12(AbstractKernel):
"""The Matérn kernel with smoothness parameter fixed at 0.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matérn 1/2 kernel",
) -> None:
spectral_density = build_student_t_distribution(nu=1)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(
self,
params: Dict,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with
lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`

.. math::
k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg)

Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call
x (Float[Array, "D"]): The left hand argument of the kernel function's call.
y (Float[Array, "D"]): The right hand argument of the kernel function's call
Returns:
Float[Array, "1"]: The value of :math:`k(x, y)`
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
K = params["variance"] * jnp.exp(-euclidean_distance(x, y))
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
K = self.variance * jnp.exp(-euclidean_distance(x, y))
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
}
@property
def spectral_density(self) -> dx.Distribution:
return build_student_t_distribution(nu=1)
47 changes: 17 additions & 30 deletions gpjax/kernels/stationary/matern32.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,36 +13,30 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution
from ..computations import DenseKernelComputation
from .utils import build_student_t_distribution, euclidean_distance


@dataclass
class Matern32(AbstractKernel):
"""The Matérn kernel with smoothness parameter fixed at 1.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 3/2",
) -> None:
spectral_density = build_student_t_distribution(nu=3)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(
self,
params: Dict,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],
x: Float[Array, "D"],
y: Float[Array, "D"],
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with
lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
Expand All @@ -51,25 +45,18 @@ def __call__(
k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg)

Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
x (Float[Array, "D"]): The left hand argument of the kernel function's call.
y (Float[Array, "D"]): The right hand argument of the kernel function's call.

Returns:
Float[Array, "1"]: The value of :math:`k(x, y)`.
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
tau = euclidean_distance(x, y)
K = (
params["variance"]
* (1.0 + jnp.sqrt(3.0) * tau)
* jnp.exp(-jnp.sqrt(3.0) * tau)
)
K = self.variance * (1.0 + jnp.sqrt(3.0) * tau) * jnp.exp(-jnp.sqrt(3.0) * tau)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
}
@property
def spectral_density(self):
return build_student_t_distribution(nu=3)
39 changes: 16 additions & 23 deletions gpjax/kernels/stationary/matern52.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,28 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from .utils import euclidean_distance, build_student_t_distribution
from ..computations import DenseKernelComputation
from .utils import build_student_t_distribution, euclidean_distance


@dataclass
class Matern52(AbstractKernel):
"""The Matérn kernel with smoothness parameter fixed at 2.5."""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Matern 5/2",
) -> None:
spectral_density = build_student_t_distribution(nu=5)
super().__init__(DenseKernelComputation, active_dims, spectral_density, name)
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
self, x: Float[Array, "D"], y: Float[Array, "D"]
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with
lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
Expand All @@ -47,25 +43,22 @@ def __call__(
k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg)

Args:
params (Dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
x (Float[Array, "D"]): The left hand argument of the kernel function's call.
y (Float[Array, "D"]): The right hand argument of the kernel function's call.

Returns:
Float[Array, "1"]: The value of :math:`k(x, y)`.
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
tau = euclidean_distance(x, y)
K = (
params["variance"]
self.variance
* (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau))
* jnp.exp(-jnp.sqrt(5.0) * tau)
)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
}
@property
def spectral_density(self):
return build_student_t_distribution(nu=5)
42 changes: 16 additions & 26 deletions gpjax/kernels/stationary/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,44 @@
import jax
import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array
from jaxtyping import Array, Float

from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)

from dataclasses import dataclass
from ...parameters import param_field, Softplus


@dataclass
class Periodic(AbstractKernel):
"""The periodic kernel.

Key reference is MacKay 1998 - "Introduction to Gaussian processes".
"""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Periodic",
) -> None:
super().__init__(
DenseKernelComputation, active_dims, spectral_density=None, name=name
)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)

def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma`

TODO: update docstring

.. math::
k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg)

Args:
x (jax.Array): The left hand argument of the kernel function's call.
y (jax.Array): The right hand argument of the kernel function's call
params (dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "D"]): The left hand argument of the kernel function's call.
y (Float[Array, "D"]): The right hand argument of the kernel function's call
Returns:
Array: The value of :math:`k(x, y)`
Float[Array, "1"]: The value of :math:`k(x, y)`
"""
x = self.slice_input(x)
y = self.slice_input(y)
sine_squared = (
jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"]
) ** 2
K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
sine_squared = (jnp.sin(jnp.pi * (x - y) / self.period) / self.lengthscale) ** 2
K = self.variance * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
"period": jnp.array([1.0] * self.ndims),
}
43 changes: 15 additions & 28 deletions gpjax/kernels/stationary/powered_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,46 @@
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from typing import Dict, List, Optional

import jax
import jax.numpy as jnp
from jax.random import KeyArray
from jaxtyping import Array
from jaxtyping import Array, Float

from ...parameters import Softplus, param_field
from ..base import AbstractKernel
from ..computations import (
DenseKernelComputation,
)
from ..computations import DenseKernelComputation
from .utils import euclidean_distance


@dataclass
class PoweredExponential(AbstractKernel):
"""The powered exponential family of kernels.

Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics".

"""

def __init__(
self,
active_dims: Optional[List[int]] = None,
name: Optional[str] = "Powered exponential",
) -> None:
super().__init__(
DenseKernelComputation, active_dims, spectral_density=None, name=name
)
self._stationary = True
lengthscale: Float[Array, "D"] = param_field(jnp.array([1.0]), bijector=Softplus)
variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=Softplus)
power: Float[Array, "1"] = param_field(jnp.array([1.0]))

def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`.

.. math::
k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg)

Args:
x (jax.Array): The left hand argument of the kernel function's call.
y (jax.Array): The right hand argument of the kernel function's call
params (dict): Parameter set for which the kernel should be evaluated on.
x (Float[Array, "D"]): The left hand argument of the kernel function's call.
y (Float[Array, "D"]): The right hand argument of the kernel function's call

Returns:
Array: The value of :math:`k(x, y)`
Float[Array, "1"]: The value of :math:`k(x, y)`
"""
x = self.slice_input(x) / params["lengthscale"]
y = self.slice_input(y) / params["lengthscale"]
K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"])
x = self.slice_input(x) / self.lengthscale
y = self.slice_input(y) / self.lengthscale
K = self.variance * jnp.exp(-euclidean_distance(x, y) ** self.power)
return K.squeeze()

def init_params(self, key: KeyArray) -> Dict:
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
"power": jnp.array([1.0]),
}
Loading