diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f9dcc1ae..63240425 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -6,7 +6,6 @@ repos:
     rev: 23.3.0
     hooks:
       - id: black
-        language_version: python3.8
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.4.0
     hooks:
diff --git a/docs/conf.py b/docs/conf.py
index c6084f94..fe0b9d34 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -59,9 +59,14 @@ def find_version(*file_paths):
 copyright = "2021, Thomas Pinder"
 author = "Thomas Pinder"
 
+from os.path import (
+    dirname,
+    join,
+    pardir,
+)
+
 # The full version, including alpha/beta/rc tags
 import sys
-from os.path import dirname, join, pardir
 
 sys.path.insert(0, join(dirname(__file__), pardir))
diff --git a/docs/conf_sphinx_patch.py b/docs/conf_sphinx_patch.py
index 7fe2b54d..714a21bb 100644
--- a/docs/conf_sphinx_patch.py
+++ b/docs/conf_sphinx_patch.py
@@ -1,6 +1,12 @@
 # This file is credited to the Flax authors.
-from typing import Any, Dict, List, Set, Tuple
+from typing import (
+    Any,
+    Dict,
+    List,
+    Set,
+    Tuple,
+)
 
 import sphinx.ext.autodoc
 import sphinx.ext.autosummary.generate as ag
diff --git a/gpjax/__init__.py b/gpjax/__init__.py
index 50aa5fcd..df0c50f2 100644
--- a/gpjax/__init__.py
+++ b/gpjax/__init__.py
@@ -23,7 +23,7 @@
 from gpjax.likelihoods import (
     Bernoulli,
     Gaussian,
-    Poisson
+    Poisson,
 )
 from gpjax.mean_functions import (
     Constant,
diff --git a/gpjax/gps.py b/gpjax/gps.py
index 35bd6af0..b0cc3a05 100644
--- a/gpjax/gps.py
+++ b/gpjax/gps.py
@@ -19,7 +19,6 @@
 from beartype.typing import (
     Any,
     Callable,
-    Dict,
     Optional,
 )
 import jax.numpy as jnp
@@ -269,17 +268,18 @@ def sample_approx(
             FunctionalSample: A function representing an approximate sample from the Gaussian
                 process prior.
         """
-        if (not isinstance(num_features, int)) or num_features <= 0:
-            raise ValueError("num_features must be a positive integer")
         if (not isinstance(num_samples, int)) or num_samples <= 0:
             raise ValueError("num_samples must be a positive integer")
 
-        approximate_kernel = RFF(base_kernel=self.kernel, num_basis_fns=num_features)
+        # sample fourier features
+        fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)
+
+        # sample fourier weights
         feature_weights = normal(key, [num_samples, 2 * num_features])  # [B, L]
 
         def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
-            feature_evals = approximate_kernel.compute_features(x=test_inputs)
-            feature_evals *= jnp.sqrt(self.kernel.variance / num_features)
+            feature_evals = fourier_feature_fn(test_inputs)  # [N, L]
             evaluated_sample = jnp.inner(feature_evals, feature_weights)  # [N, B]
             return self.mean_function(test_inputs) + evaluated_sample
@@ -501,24 +501,13 @@ def sample_approx(
             FunctionalSample: A function representing an approximate sample from the Gaussian
                 process prior.
""" - if (not isinstance(num_features, int)) or num_features <= 0: - raise ValueError("num_features must be a positive integer") if (not isinstance(num_samples, int)) or num_samples <= 0: raise ValueError("num_samples must be a positive integer") - # Approximate kernel with feature decomposition - approximate_kernel = RFF( - base_kernel=self.prior.kernel, num_basis_fns=num_features - ) - - def eval_fourier_features( - test_inputs: Float[Array, "N D"] - ) -> Float[Array, "N L"]: - Phi = approximate_kernel.compute_features(x=test_inputs) - Phi *= jnp.sqrt(self.prior.kernel.variance / num_features) - return Phi + # sample fourier features + fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key) - # sample weights for Fourier features + # sample fourier weights fourier_weights = normal(key, [num_samples, 2 * num_features]) # [B, L] # sample weights v for canonical features @@ -531,13 +520,13 @@ def eval_fourier_features( key, [train_data.n, num_samples] ) # [N, B] y = train_data.y - self.prior.mean_function(train_data.X) # account for mean - Phi = eval_fourier_features(train_data.X) + Phi = fourier_feature_fn(train_data.X) canonical_weights = Sigma.solve( y + eps - jnp.inner(Phi, fourier_weights) ) # [N, B] def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: - fourier_features = eval_fourier_features(test_inputs) + fourier_features = fourier_feature_fn(test_inputs) # [n, L] weight_space_contribution = jnp.inner( fourier_features, fourier_weights ) # [n, B] @@ -634,6 +623,9 @@ def predict( return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) +####################### +# Utils +####################### def construct_posterior( prior: Prior, likelihood: AbstractLikelihood ) -> AbstractPosterior: @@ -658,6 +650,37 @@ def construct_posterior( return NonConjugatePosterior(prior=prior, likelihood=likelihood) +def _build_fourier_features_fn( + prior: Prior, num_features: int, key: KeyArray +) -> Callable[[Float[Array, "N D"]], Float[Array, "N L"]]: + """Return a function that evaluates features sampled from the Fourier feature + decomposition of the prior's kernel. + + Args: + prior (Prior): The Prior distribution. + num_features (int): The number of feature functions to be sampled. + key (KeyArray): The random seed used. + + Returns + ------- + Callable: A callable function evaluation the sampled feature functions. 
+ """ + if (not isinstance(num_features, int)) or num_features <= 0: + raise ValueError("num_features must be a positive integer") + + # Approximate kernel with feature decomposition + approximate_kernel = RFF( + base_kernel=prior.kernel, num_basis_fns=num_features, key=key + ) + + def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L"]: + Phi = approximate_kernel.compute_features(x=test_inputs) + Phi *= jnp.sqrt(prior.kernel.variance / num_features) + return Phi + + return eval_fourier_features + + __all__ = [ "AbstractPrior", "Prior", diff --git a/tests/test_gps.py b/tests/test_gps.py index f17790ed..ba378b7d 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -21,7 +21,6 @@ ValidationErrors = ValueError from dataclasses import is_dataclass -import shutil from typing import Callable from jax.config import config @@ -31,11 +30,6 @@ import pytest import tensorflow_probability.substrates.jax.distributions as tfd -from gpjax.base import ( - load_tree, - save_tree, -) - # from gpjax.dataset import Dataset from gpjax.dataset import Dataset from gpjax.gaussian_distribution import GaussianDistribution @@ -50,15 +44,13 @@ from gpjax.kernels import ( RBF, AbstractKernel, - Matern12, - Matern32, Matern52, ) from gpjax.likelihoods import ( AbstractLikelihood, Bernoulli, Gaussian, - Poisson + Poisson, ) from gpjax.mean_functions import ( AbstractMeanFunction, @@ -272,7 +264,7 @@ def test_posterior_construct( @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) def test_prior_sample_approx(num_datapoints, kernel, mean_function): - kern = kernel(lengthscale=5.0, variance=0.1) + kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1) p = Prior(kernel=kern, mean_function=mean_function) key = jr.PRNGKey(123) @@ -292,7 +284,7 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): sampled_fn = p.sample_approx(1, key, 100) assert isinstance(sampled_fn, Callable) # check type - x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) assert evals.shape == (num_datapoints, 1.0) # check shape @@ -325,16 +317,17 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): @pytest.mark.parametrize("kernel", [RBF, Matern52]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function): - kern = kernel(lengthscale=5.0, variance=0.1) + kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1) p = Prior(kernel=kern, mean_function=mean_function) * Gaussian( num_datapoints=num_datapoints ) key = jr.PRNGKey(123) - x = jnp.sort( - jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)), - axis=0, + + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) + y = ( + jnp.mean(jnp.sin(x), 1, keepdims=True) + + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 ) - y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 D = Dataset(X=x, y=y) with pytest.raises(ValueError): @@ -353,7 +346,7 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function sampled_fn = p.sample_approx(1, D, key, 100) assert isinstance(sampled_fn, Callable) # check type - x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2)) evals = sampled_fn(x) assert evals.shape == (num_datapoints, 1.0) # 
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index 14d98549..c5b27d17 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -200,7 +200,6 @@ class TestPoisson(BaseTestLikelihood):
     def _test_call_check(
         likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist
     ):
-        # Test call method.
         pred_dist = likelihood(latent_dist)
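
For orientation, a minimal usage sketch of the `sample_approx` API after this refactor. This is not part of the patch: it assumes only the public GPJax imports already exercised in tests/test_gps.py above, and the sample counts and shapes follow the [N, B] convention annotated in the diff.

import jax.numpy as jnp
import jax.random as jr

from gpjax.dataset import Dataset
from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Gaussian
from gpjax.mean_functions import Zero

key = jr.PRNGKey(123)

# Prior with a 2D ARD kernel, mirroring the test setup above.
kern = RBF(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
prior = Prior(kernel=kern, mean_function=Zero())

# Draw B=5 approximate prior samples built from L=100 Fourier features;
# the feature construction is now shared via _build_fourier_features_fn.
prior_fn = prior.sample_approx(5, key, 100)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(50, 2))
assert prior_fn(x).shape == (50, 5)  # [N, B]

# The conjugate-posterior variant additionally conditions on a Dataset,
# combining the Fourier (weight-space) term with canonical features.
posterior = prior * Gaussian(num_datapoints=50)
y = jnp.mean(jnp.sin(x), 1, keepdims=True) + jr.normal(key=key, shape=(50, 1)) * 0.1
posterior_fn = posterior.sample_approx(5, Dataset(X=x, y=y), key, 100)
assert posterior_fn(x).shape == (50, 5)  # [n, B]

Because both code paths now validate num_features inside _build_fourier_features_fn, passing a non-positive or non-integer num_features raises the same ValueError from either entry point.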