optional solver_algorithm parameter to sample_approx (#478)
* Added an optional `solver_algorithm` parameter to the `ConjugatePosterior` class for solving linear systems with the covariance matrix (see the usage sketch below).
* Updated the `test_conjugate_posterior_sample_approx` test function to cover the `solver_algorithm` parameter.

* remove commented import line

* bump the version
theorashid authored Sep 18, 2024
1 parent b03bdb5 commit 6eaab4a
Showing 3 changed files with 22 additions and 8 deletions.
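A minimal usage sketch of what this commit enables. The toy data, kernel settings, and top-level import paths (`gpjax.kernels`, `gpjax.likelihoods`, `gpjax.mean_functions`) are illustrative assumptions modeled on the updated test below, not part of the diff:

```python
import jax.numpy as jnp
import jax.random as jr
from cola.linalg.inverse.cg import CG

from gpjax.dataset import Dataset
from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Gaussian
from gpjax.mean_functions import Zero

key = jr.key(0)

# Illustrative training data: 50 two-dimensional inputs with noisy targets.
x = jr.uniform(key, minval=-2.0, maxval=2.0, shape=(50, 2))
y = jnp.sin(x[:, :1]) + 0.1 * jr.normal(key, shape=(50, 1))
D = Dataset(X=x, y=y)

# Conjugate (Gaussian-likelihood) posterior, as in the updated test.
posterior = Prior(
    kernel=RBF(lengthscale=jnp.array([1.0, 1.0]), variance=1.0),
    mean_function=Zero(),
) * Gaussian(num_datapoints=D.n)

# Draw 5 approximate posterior samples, solving the covariance system
# with conjugate gradients rather than the default Cholesky factorisation.
sample_fn = posterior.sample_approx(5, D, key, 100, solver_algorithm=CG())
samples = sample_fn(x)  # shape (50, 5)
```

Omitting `solver_algorithm` keeps the previous behaviour, since the parameter defaults to `Cholesky()`.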
gpjax/__init__.py (1 addition, 1 deletion)
@@ -40,7 +40,7 @@
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.9.1"
+__version__ = "0.9.2"
 
 __all__ = [
     "base",
gpjax/gps.py (8 additions, 1 deletion)
@@ -17,6 +17,7 @@
 
 import beartype.typing as tp
 from cola.annotations import PSD
+from cola.linalg.algorithm_base import Algorithm
 from cola.linalg.decompositions.decompositions import Cholesky
 from cola.linalg.inverse.inv import solve
 from cola.ops.operators import I_like
@@ -530,6 +531,7 @@ def sample_approx(
         train_data: Dataset,
         key: KeyArray,
         num_features: int | None = 100,
+        solver_algorithm: tp.Optional[Algorithm] = Cholesky(),
     ) -> FunctionalSample:
         r"""Draw approximate samples from the Gaussian process posterior.
@@ -563,6 +565,11 @@
             key (KeyArray): The random seed used for the sample(s).
             num_features (int): The number of features used when approximating the
                 kernel.
+            solver_algorithm (Optional[Algorithm], optional): The algorithm used to solve
+                linear systems with the covariance matrix. See the
+                [CoLA documentation](https://cola.readthedocs.io/en/latest/package/cola.linalg.html#algorithms)
+                for guidance on choosing a solver. For PSD matrices, CoLA currently recommends Cholesky() for small
+                matrices and CG() for larger ones. Select Auto() to let CoLA decide. Defaults to Cholesky().
 
         Returns:
             FunctionalSample: A function representing an approximate sample from the Gaussian
@@ -588,7 +595,7 @@
         canonical_weights = solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
-            Cholesky(),
+            solver_algorithm,
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
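For orientation: the `solve` in this hunk is the linear-system step of the decoupled (pathwise) approximate sampling that `sample_approx` implements. A sketch of the quantity being computed, assuming the standard Wilson et al. (2020) formulation that the surrounding code suggests, where `w` are the Fourier feature weights, `Phi` the feature matrix, and `eps` a noise draw:

```latex
% Sketch under the assumed decoupled-sampling formulation, not verbatim source:
% a Fourier-feature prior draw plus a data-driven correction.
\[
  f(x_\ast) \;\approx\; \Phi(x_\ast)\, w \;+\; k(x_\ast, X)\, v,
  \qquad
  v \;=\; \Sigma^{-1}\!\left(y + \varepsilon - \Phi w\right),
  \qquad
  \Sigma \;=\; K_{XX} + \sigma^2 I .
\]
```

The new `solver_algorithm` argument controls only how the system for the canonical weights `v` is solved; the default `Cholesky()` reproduces the previous hard-coded behaviour exactly.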
tests/test_gps.py (13 additions, 6 deletions)
@@ -25,13 +25,15 @@
     Type,
 )
 
+from cola.linalg.algorithm_base import Auto
 from cola.linalg.decompositions.decompositions import Cholesky
+from cola.linalg.inverse.cg import CG
 from jax import config
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 import tensorflow_probability.substrates.jax.distributions as tfd
 
-# from gpjax.dataset import Dataset
+from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
 from gpjax.gps import (
@@ -283,7 +285,10 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
 @pytest.mark.parametrize("num_datapoints", [1, 5])
 @pytest.mark.parametrize("kernel", [RBF, Matern52])
 @pytest.mark.parametrize("mean_function", [Zero, Constant])
-def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
+@pytest.mark.parametrize("solver_algorithm", [Cholesky(), CG(), Auto()])
+def test_conjugate_posterior_sample_approx(
+    num_datapoints, kernel, mean_function, solver_algorithm
+):
     kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
     p = Prior(kernel=kern, mean_function=mean_function()) * Gaussian(
         num_datapoints=num_datapoints
@@ -310,26 +315,28 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
     # with pytest.raises(ValidationErrors):
     #     p.sample_approx(1, D, key, 0.5)
 
-    sampled_fn = p.sample_approx(1, D, key, 100)
+    sampled_fn = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     assert isinstance(sampled_fn, Callable)  # check type
 
     x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
     evals = sampled_fn(x)
     assert evals.shape == (num_datapoints, 1.0)  # check shape
 
-    sampled_fn_2 = p.sample_approx(1, D, key, 100)
+    sampled_fn_2 = p.sample_approx(1, D, key, 100, solver_algorithm=solver_algorithm)
     evals_2 = sampled_fn_2(x)
     max_delta = jnp.max(jnp.abs(evals - evals_2))
     assert max_delta == 0.0  # samples same for same seed
 
     new_key = jr.key(12345)
-    sampled_fn_3 = p.sample_approx(1, D, new_key, 100)
+    sampled_fn_3 = p.sample_approx(
+        1, D, new_key, 100, solver_algorithm=solver_algorithm
+    )
     evals_3 = sampled_fn_3(x)
     max_delta = jnp.max(jnp.abs(evals - evals_3))
     assert max_delta > 0.01  # samples different for different seed
 
     # Check validity of samples using Monte-Carlo
-    sampled_fn = p.sample_approx(10_000, D, key, 100)
+    sampled_fn = p.sample_approx(10_000, D, key, 100, solver_algorithm=solver_algorithm)
     sampled_evals = sampled_fn(x)
     approx_mean = jnp.mean(sampled_evals, -1)
     approx_var = jnp.var(sampled_evals, -1)
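The hunk is truncated above, so the test's final assertions are not shown. Purely as a hypothetical sketch of how such a Monte-Carlo check might conclude (the `predict` call, tolerances, and assertions below are assumptions, not the repository's elided code):

```python
# Hypothetical continuation -- illustrative only, not the repository's elided assertions.
exact = p.predict(x, train_data=D)  # analytic latent posterior at the test inputs
assert jnp.allclose(approx_mean, exact.mean(), atol=0.1)     # MC mean near exact mean
assert jnp.allclose(approx_var, exact.variance(), atol=0.1)  # MC variance near exact variance
```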
