Skip to content

Commit

Permalink
[skip ci] change how dimensions are specified for kernels, update ker…
Browse files Browse the repository at this point in the history
…nel tests
  • Loading branch information
frazane committed Mar 12, 2024
1 parent d52530e commit 15132dd
Show file tree
Hide file tree
Showing 34 changed files with 630 additions and 656 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ repos:
types: [python]
- id: ruff
name: ruff
entry: ruff
entry: ruff check
args: ["--exit-non-zero-on-fix"]
require_serial: true
language: system
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ D = gpx.Dataset(X=x, y=y)

# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF(1)
kernel = gpx.kernels.RBF()
prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)

# Define a likelihood
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
posterior = (
gpx.gps.Prior(
mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF(1)
mean_function=gpx.mean_functions.Constant(), kernel=gpx.kernels.RBF()
)
* likelihood
)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
# choose a Bernoulli likelihood with a probit link function.

# %%
kernel = gpx.kernels.RBF(1)
kernel = gpx.kernels.RBF()
meanf = gpx.mean_functions.Constant()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@

# %%
meanf = gpx.mean_functions.Constant()
kernel = gpx.kernels.RBF(active_dims=1) # 1-dimensional inputs
kernel = gpx.kernels.RBF() # 1-dimensional inputs
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
posterior = prior * likelihood
Expand Down Expand Up @@ -221,7 +221,7 @@

# %%
full_rank_model = gpx.gps.Prior(
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF(1)
mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
) * gpx.likelihoods.Gaussian(num_datapoints=D.n)
nmll = jit(lambda: -gpx.objectives.conjugate_mll(full_rank_model, D))
# %timeit nmll().block_until_ready()
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
# kernel.

# %%
kernel = gpx.kernels.RBF(active_dims=1) # 1-dimensional input
kernel = gpx.kernels.RBF() # 1-dimensional input
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

Expand Down
2 changes: 1 addition & 1 deletion docs/examples/uncollapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
# %%
meanf = gpx.mean_functions.Zero()
likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
kernel = jk.RBF(active_dims=1) # 1-dimensional inputs
kernel = jk.RBF() # 1-dimensional inputs
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
p = prior * likelihood
q = gpx.variational_families.VariationalGaussian(posterior=p, inducing_inputs=z)
Expand Down
2 changes: 0 additions & 2 deletions gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from flax.experimental import nnx
import jax
from jax._src.random import _check_prng_key
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
import jax.random as jr
Expand Down Expand Up @@ -124,7 +123,6 @@ def fit( # noqa: PLR0913
_check_optim(optim)
_check_num_iters(num_iters)
_check_batch_size(batch_size)
_check_prng_key(key)
_check_log_rate(log_rate)
_check_verbose(verbose)

Expand Down
12 changes: 6 additions & 6 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Prior(AbstractPrior[M, K]):
Example:
```python
>>> import gpjax as gpx
>>> kernel = gpx.kernels.RBF(active_dims=1)
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
```
Expand Down Expand Up @@ -178,7 +178,7 @@ def __mul__(self, other): # noqa: F811
Example:
>>> import gpjax as gpx
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF(1)
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>> prior * likelihood
Expand Down Expand Up @@ -239,7 +239,7 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
Example:
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> kernel = gpx.kernels.RBF(1)
>>> kernel = gpx.kernels.RBF()
>>> mean_function = gpx.mean_functions.Zero()
>>> prior = gpx.gps.Prior(mean_function=mean_function, kernel=kernel)
>>> prior.predict(jnp.linspace(0, 1, 100)[:, None])
Expand Down Expand Up @@ -297,7 +297,7 @@ def sample_approx(
>>> key = jr.PRNGKey(123)
...
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF(active_dims=1)
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
...
>>> sample_fn = prior.sample_approx(10, key)
Expand Down Expand Up @@ -436,7 +436,7 @@ class ConjugatePosterior(AbstractPosterior[P, GL]):
...
>>> prior = gpx.gps.Prior(
mean_function = gpx.mean_functions.Zero(),
kernel = gpx.kernels.RBF(active_dims=1)
kernel = gpx.kernels.RBF()
)
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
...
Expand Down Expand Up @@ -478,7 +478,7 @@ def predict(
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
...
>>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF(1))
>>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n)
>>> predictive_dist = posterior(xtest, D)
Expand Down
6 changes: 6 additions & 0 deletions gpjax/kernels/approximations/rff.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def __init__(

if self.frequencies is None:
n_dims = self.base_kernel.n_dims
if n_dims is None:
raise ValueError(
"Expected the number of dimensions to be specified for the base kernel. "
"Please specify the n_dims argument for the base kernel."
)

self.frequencies = Static(
self.base_kernel.spectral_density.sample(
seed=key, sample_shape=(self.num_basis_fns, n_dims)
Expand Down
127 changes: 69 additions & 58 deletions gpjax/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,45 +49,44 @@ class AbstractKernel(nnx.Module):
relevant columns for the kernel's evaluation.
Attributes:
active_dims (tp.Union[list[int], slice]): The indices of the input dimensions
that are active in the kernel's evaluation. If active_dims is a list, then
the input to the kernel is indexed by the list, and n_dims
is the length of the list. If active_dims is an integer, then the input to the
kernel is not indexed, and n_dims is the value of the integer.
If active_dims is a slice, then the input to the kernel is indexed by the slice,
and n_dims is the length of the slice. Importantly, n_dims must always be
inferable from active_dims.
compute_engine (AbstractKernelComputation): The computation engine that is used to
active_dims: The indices of the input dimensions
that are active in the kernel's evaluation, represented by a list of integers
or a slice object.
compute_engine: The computation engine that is used to
compute the kernel's cross-covariance and gram matrices.
n_dims (int): The number of input dimensions of the kernel.
name (str): The name of the kernel.
n_dims: The number of input dimensions of the kernel.
name: The name of the kernel.
"""

active_dims: tp.Union[list[int], slice]
active_dims: tp.Union[list[int], slice] = slice(None)
compute_engine: AbstractKernelComputation
n_dims: int
n_dims: tp.Union[int, None]
name: str = "AbstractKernel"

def __init__(
    self,
    active_dims: tp.Union[list[int], slice, None] = None,
    n_dims: tp.Union[int, None] = None,
    compute_engine: AbstractKernelComputation = DenseKernelComputation(),
):
    """Initialise the AbstractKernel class.

    Args:
        active_dims: the indices of the input dimensions
            that are active in the kernel's evaluation, represented by a list of
            integers or a slice object. Defaults to a full slice.
        n_dims: the number of input dimensions of the kernel.
        compute_engine: the computation engine that is used to compute the kernel's
            cross-covariance and gram matrices. Defaults to DenseKernelComputation.

    Raises:
        TypeError: if `active_dims` is neither a list nor a slice, or if `n_dims`
            is neither an integer nor None.
        ValueError: if `active_dims` and `n_dims` disagree on the number of
            dimensions.
    """
    # Only an omitted value (None) falls back to the full slice. Using
    # `active_dims or slice(None)` would also swallow other falsy inputs such as
    # `[]` or `0`, letting them bypass the type check below.
    if active_dims is None:
        active_dims = slice(None)

    _check_active_dims(active_dims)
    _check_n_dims(n_dims)

    self.active_dims, self.n_dims = _check_dims_compat(active_dims, n_dims)

    self.compute_engine = compute_engine

def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
Expand Down Expand Up @@ -184,7 +183,7 @@ class Constant(AbstractKernel):

def __init__(
self,
active_dims: tp.Union[list[int], int, slice, None] = 1,
active_dims: tp.Union[list[int], slice, None] = None,
constant: tp.Union[ScalarFloat, Parameter[ScalarFloat]] = jnp.array(0.0),
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
):
Expand Down Expand Up @@ -232,9 +231,7 @@ def __init__(
self.kernels = kernels_list
self.operator = operator

active_dims = ft.reduce(lambda asum, x: asum + x.n_dims, kernels_list, 0)

super().__init__(active_dims=active_dims, compute_engine=compute_engine)
super().__init__(compute_engine=compute_engine)

def __call__(
self,
Expand All @@ -254,40 +251,54 @@ def __call__(
return self.operator(jnp.stack([k(x, y) for k in self.kernels]))


@tp.overload
def _check_active_dims(active_dims: list[int]) -> tuple[int, list[int]]:
...


@tp.overload
def _check_active_dims(active_dims: int) -> tuple[int, slice]: # noqa: F811
...
def _check_active_dims(active_dims: tp.Any):
if not isinstance(active_dims, (list, slice)):
raise TypeError(
f"Expected active_dims to be a list or slice. Got {active_dims} instead."
)


@tp.overload
def _check_active_dims(active_dims: slice) -> tuple[int, slice]: # noqa: F811
...
def _check_n_dims(n_dims: tp.Any):
if not isinstance(n_dims, (int, type(None))):
raise TypeError(
"Expected n_dims to be an integer or None (unspecified)."
f" Got {n_dims} instead."
)


def _check_active_dims(active_dims: tp.Union[list[int], int, slice]): # noqa: F811
if isinstance(active_dims, list):
return len(active_dims), active_dims
elif isinstance(active_dims, int):
return active_dims, slice(None)
elif isinstance(active_dims, slice):
if active_dims.stop is None:
raise ValueError("active_dims slice must have a stop value.")
if active_dims.stop < 0:
raise ValueError("active_dims slice stop value must be positive.")
def _check_dims_compat(
active_dims: tp.Union[list[int], slice],
n_dims: tp.Union[int, None],
):
err = ValueError(
"Expected the length of active_dims to be equal to the specified n_dims."
f" Got {active_dims} active dimensions and {n_dims} input dimensions."
)

if isinstance(active_dims, list) and isinstance(n_dims, int):
if len(active_dims) != n_dims:
raise err

if isinstance(active_dims, slice) and isinstance(n_dims, int):
start = active_dims.start or 0
stop = active_dims.stop or n_dims
step = active_dims.step or 1
if len(range(start, stop, step)) != n_dims:
raise err

if isinstance(active_dims, list) and n_dims is None:
n_dims = len(active_dims)

if isinstance(active_dims, slice) and n_dims is None:
if active_dims == slice(None):
pass
else:
start = active_dims.start or 0
stop = active_dims.stop or n_dims
step = active_dims.step or 1
n_dims = len(range(start, stop, step))

start = active_dims.start if active_dims.start is not None else 0
step = active_dims.step if active_dims.step is not None else 1
return (active_dims.stop - start) // step, active_dims
else:
raise TypeError(
"Expected active_dims to be a list, int or slice."
f" Got {type(active_dims)} instead."
)
return active_dims, n_dims


SumKernel = ft.partial(CombinationKernel, operator=jnp.sum)
Expand Down
5 changes: 3 additions & 2 deletions gpjax/kernels/non_euclidean/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ class GraphKernel(StationaryKernel):
def __init__(
self,
laplacian: Num[Array, "N N"],
active_dims: tp.Union[list[int], int, slice],
active_dims: tp.Union[list[int], slice, None] = None,
lengthscale: tp.Union[ScalarFloat, Float[Array, " D"], Parameter] = 1.0,
variance: tp.Union[ScalarFloat, Parameter] = 1.0,
smoothness: ScalarFloat = 1.0,
n_dims: tp.Union[int, None] = None,
compute_engine: AbstractKernelComputation = EigenKernelComputation(),
):
if isinstance(smoothness, Parameter):
Expand All @@ -76,7 +77,7 @@ def __init__(
self.eigenvalues = Static(evals.reshape(-1, 1))
self.num_vertex = self.eigenvalues.value.shape[0]

super().__init__(active_dims, lengthscale, variance, compute_engine)
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)

def __call__( # TODO not consistent with general kernel interface
self,
Expand Down
5 changes: 3 additions & 2 deletions gpjax/kernels/nonstationary/arccosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ class ArcCosine(AbstractKernel):

def __init__(
self,
active_dims: tp.Union[list[int], int, slice],
active_dims: tp.Union[list[int], slice, None] = None,
order: tp.Literal[0, 1, 2] = 0,
variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
weight_variance: tp.Union[
WeightVarianceCompatible, nnx.Variable[WeightVariance]
] = 1.0,
bias_variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
n_dims: tp.Union[int, None] = None,
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
):
if order not in [0, 1, 2]:
Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(

self.name = f"ArcCosine (order {self.order})"

super().__init__(active_dims=active_dims, compute_engine=compute_engine)
super().__init__(active_dims, n_dims, compute_engine)

def __call__(self, x: Float[Array, " D"], y: Float[Array, " D"]) -> ScalarArray:
r"""Evaluate the kernel on a pair of inputs $`(x, y)`$
Expand Down
5 changes: 3 additions & 2 deletions gpjax/kernels/nonstationary/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ class Linear(AbstractKernel):

def __init__(
self,
active_dims: tp.Union[list[int], int, slice],
active_dims: tp.Union[list[int], slice, None] = None,
variance: tp.Union[ScalarFloat, nnx.Variable[ScalarArray]] = 1.0,
n_dims: tp.Union[int, None] = None,
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
):
super().__init__(active_dims=active_dims, compute_engine=compute_engine)
super().__init__(active_dims, n_dims, compute_engine)

if isinstance(variance, nnx.Variable):
self.variance = variance
Expand Down
Loading

0 comments on commit 15132dd

Please sign in to comment.