From 4bf4d978d42804d7eef86c9a9e27269efe2b2b63 Mon Sep 17 00:00:00 2001 From: henrymoss Date: Tue, 21 Mar 2023 14:23:54 +0000 Subject: [PATCH 1/4] wip --- gpjax/gps.py | 89 +++++++++++++++++++++++++++++++++++++++++ gpjax/mean_functions.py | 5 ++- tests/test_gps.py | 59 ++++++++++++++++++++++++++- 3 files changed, 150 insertions(+), 3 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 7d72b2ff..1731a97f 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -17,11 +17,13 @@ from typing import Any, Callable, Dict, Optional import distrax as dx +import jax import jax.numpy as jnp from jaxtyping import Array, Float from jax.random import KeyArray from jaxlinop import identity +from jaxkern import RFF from jaxkern.base import AbstractKernel from jaxutils import PyTree @@ -101,6 +103,14 @@ def _initialise_params(self, key: KeyArray) -> Dict: return self.init_params(key) +FunctionalSample = Callable[[Float[Array, "N D"]], Float[Array, "N B"]] +""" Type alias for functions representing `B` samples from a model, to be evaluated on any set of +`N` inputs (of dimension `D`) and returning the evaluations of each (potentially approximate) +sample draw across these inputs. +""" + + + ####################### # GP Priors ####################### @@ -247,6 +257,84 @@ def predict_fn( return predict_fn + def sample_approx( + self, + num_samples: int, + params: Dict, + seed: KeyArray, + num_features: Optional[int]=100, + ) -> FunctionalSample: + """Build an approximate sample from the Gaussian process prior. This method + provides a function that returns the evaluations of a sample across any given + inputs. + + In particular, we approximate the Gaussian processes' prior as the finite feature + approximation + + .. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + + + where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of + the model's kernel and :math:`\theta_i` are samples from a unit Gaussian. + + + A key property of such functional samples is that the same sample draw is + evaluated for all queries. Consistency is a property that is prohibitively costly + to ensure when sampling exactly from the GP prior, as the cost of exact sampling + scales cubically with the size of the sample. In contrast, finite feature representations + can be evaluated with constant cost regardless of the required number of queries. + + In the following example, we build 10 such samples + and then evaluate them over the interval :math:`[0, 1]`: + + Example: + >>> import gpjax as gpx + >>> import jax.numpy as jnp + >>> + >>> kernel = gpx.kernels.RBF() + >>> prior = gpx.Prior(kernel = kernel) + >>> seed = jr.PRNGKey(123) + >>> + >>> parameter_state = gpx.initialise(prior) + >>> sample_fn = prior.sample_appox(10, parameter_state.params, seed) + >>> sample_fn(jnp.linspace(0, 1, 100)) + + Args: + num_samples (int): The desired number of samples. + params (Dict): The specific set of parameters for which the sample + should be generated for. + seed (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. + + + Returns: + FunctionalSample: A function representing an approximate sample from the Gaussian + process prior. + """ + for integer_input in [num_features, num_samples]: + if (not isinstance(integer_input,int)) or integer_input<0: + raise ValueError + + approximate_kernel = RFF(self.kernel, num_features) + approximate_kernel_params = approximate_kernel.init_params(seed) + feature_weights = jax.random.normal(seed, [num_samples, 2*num_features]) # [B, L] + + def sample_fn(test_inputs: Float[Array, "N D"] + ) -> Float[Array, "N B"]: + + feature_evals = approximate_kernel.compute_engine.compute_features( # [N, L] + test_inputs, + frequencies=approximate_kernel_params["frequencies"], + scaling_factor=approximate_kernel_params["lengthscale"], + ) + feature_evals *= jnp.sqrt(params["kernel"]["variance"] / num_features) + evaluated_sample = jnp.inner(feature_evals,feature_weights) # [N, B] + return self.mean_function(params["mean_function"], test_inputs) + evaluated_sample + + return sample_fn + + def init_params(self, key: KeyArray) -> Dict: """Initialise the GP prior's parameter set. @@ -262,6 +350,7 @@ def init_params(self, key: KeyArray) -> Dict: } + ####################### # GP Posteriors ####################### diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 75ebf43e..93720d43 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -117,8 +117,9 @@ def init_params(self, key: KeyArray) -> Dict: class Constant(AbstractMeanFunction): """ - A zero mean function. This function returns a repeated scalar value for all inputs. - The scalar value itself can be treated as a model hyperparameter and learned during training. + A constant mean function. This function returns a repeated scalar value for all inputs. + The scalar value itself can be treated as a model hyperparameter and learned during training but + defaults to 1.0. """ def __init__( diff --git a/tests/test_gps.py b/tests/test_gps.py index 9bab15d8..2a66a906 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -31,7 +31,8 @@ Prior, construct_posterior, ) -from gpjax.kernels import RBF, Matern12, Matern32, Matern52 +from gpjax.mean_functions import Zero, Constant +from jaxkern import RBF, Matern12, Matern32, Matern52 from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ParameterState @@ -59,6 +60,58 @@ def test_prior(num_datapoints): assert sigma.shape == (num_datapoints, num_datapoints) +@pytest.mark.parametrize("num_datapoints", [1, 5]) +@pytest.mark.parametrize("kernel", [RBF(), Matern52()]) +@pytest.mark.parametrize("mean_function", [Zero(), Constant()]) +def test_prior_sample_approx(num_datapoints, kernel, mean_function): + p = Prior(kernel=kernel, mean_function=mean_function) + key = jr.PRNGKey(123) + parameter_state = initialise(p, key) + params, _, _ = parameter_state.unpack() + params["kernel"]["lengthscale"]=5.0 + params["kernel"]["variance"]=0.1 + + with pytest.raises(ValueError): + p.sample_approx(-1,params, key) + with pytest.raises(ValueError): + p.sample_approx(0.5,params, key) + with pytest.raises(ValueError): + p.sample_approx(1,params, key, -10) + with pytest.raises(ValueError): + p.sample_approx(1,params, key, 0.5) + + sampled_fn = p.sample_approx(1,params, key, 100) + assert isinstance(sampled_fn, tp.Callable) # check type + + x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + evals = sampled_fn(x) + assert evals.shape == (num_datapoints, 1.0) # check shape + + sampled_fn_2 = p.sample_approx(1,params, key, 100) + evals_2 = sampled_fn_2(x) + max_delta = jnp.max(jnp.abs(evals - evals_2)) + assert max_delta == 0.0 # samples same for same seed + + new_key = jr.PRNGKey(12345) + sampled_fn_3 = p.sample_approx(1,params, new_key, 100) + evals_3 = sampled_fn_3(x) + max_delta = jnp.max(jnp.abs(evals - evals_3)) + assert max_delta > 0.1 # samples different for different seed + + # Check validty of samples using Monte-Carlo + sampled_fn = p.sample_approx(10_000,params, key, 100) + sampled_evals = sampled_fn(x) + approx_mean = jnp.mean(sampled_evals, -1) + approx_var = jnp.var(sampled_evals, -1) + true_predictive = p(params)(x) + true_mean = true_predictive.mean() + true_var = jnp.diagonal(true_predictive.covariance()) + max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) + max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct + + @pytest.mark.parametrize("num_datapoints", [1, 2, 10]) def test_conjugate_posterior(num_datapoints): key = jr.PRNGKey(123) @@ -209,3 +262,7 @@ def test_initialisation_override(kernel): with pytest.raises(ValueError): parameter_state = initialise(p, key, keernel=override_params) + + + + From 6ccb2c732ca30459303dfea6a2fb825ac5c0cb34 Mon Sep 17 00:00:00 2001 From: henrymoss Date: Tue, 21 Mar 2023 14:47:15 +0000 Subject: [PATCH 2/4] format --- gpjax/gps.py | 67 +++++++++++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 37 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index 1731a97f..abf24162 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -110,7 +110,6 @@ def _initialise_params(self, key: KeyArray) -> Dict: """ - ####################### # GP Priors ####################### @@ -241,9 +240,7 @@ def predict( mean_function = self.mean_function kernel = self.kernel - def predict_fn( - test_inputs: Float[Array, "N D"] - ) -> GaussianDistribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: # Unpack test inputs t = test_inputs @@ -258,15 +255,15 @@ def predict_fn( return predict_fn def sample_approx( - self, + self, num_samples: int, params: Dict, - seed: KeyArray, - num_features: Optional[int]=100, + seed: KeyArray, + num_features: Optional[int] = 100, ) -> FunctionalSample: """Build an approximate sample from the Gaussian process prior. This method - provides a function that returns the evaluations of a sample across any given - inputs. + provides a function that returns the evaluations of a sample across any given + inputs. In particular, we approximate the Gaussian processes' prior as the finite feature approximation @@ -304,7 +301,7 @@ def sample_approx( params (Dict): The specific set of parameters for which the sample should be generated for. seed (KeyArray): The random seed used for the sample(s). - num_features (int): The number of features used when approximating the + num_features (int): The number of features used when approximating the kernel. @@ -313,28 +310,33 @@ def sample_approx( process prior. """ for integer_input in [num_features, num_samples]: - if (not isinstance(integer_input,int)) or integer_input<0: + if (not isinstance(integer_input, int)) or integer_input < 0: raise ValueError approximate_kernel = RFF(self.kernel, num_features) approximate_kernel_params = approximate_kernel.init_params(seed) - feature_weights = jax.random.normal(seed, [num_samples, 2*num_features]) # [B, L] - - def sample_fn(test_inputs: Float[Array, "N D"] - ) -> Float[Array, "N B"]: - - feature_evals = approximate_kernel.compute_engine.compute_features( # [N, L] - test_inputs, - frequencies=approximate_kernel_params["frequencies"], - scaling_factor=approximate_kernel_params["lengthscale"], - ) + feature_weights = jax.random.normal( + seed, [num_samples, 2 * num_features] + ) # [B, L] + + def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: + + feature_evals = ( + approximate_kernel.compute_engine.compute_features( # [N, L] + test_inputs, + frequencies=approximate_kernel_params["frequencies"], + scaling_factor=approximate_kernel_params["lengthscale"], + ) + ) feature_evals *= jnp.sqrt(params["kernel"]["variance"] / num_features) - evaluated_sample = jnp.inner(feature_evals,feature_weights) # [N, B] - return self.mean_function(params["mean_function"], test_inputs) + evaluated_sample + evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B] + return ( + self.mean_function(params["mean_function"], test_inputs) + + evaluated_sample + ) return sample_fn - def init_params(self, key: KeyArray) -> Dict: """Initialise the GP prior's parameter set. @@ -350,7 +352,6 @@ def init_params(self, key: KeyArray) -> Dict: } - ####################### # GP Posteriors ####################### @@ -555,9 +556,7 @@ def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) covariance += identity(n_test) * jitter - return GaussianDistribution( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict @@ -669,9 +668,7 @@ def mll( ) return constant * ( - marginal_likelihood.log_prob( - jnp.atleast_1d(y.squeeze()) - ).squeeze() + marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze() ) return mll @@ -719,9 +716,7 @@ def init_params(self, key: KeyArray) -> Dict: self.prior.init_params(key), {"likelihood": self.likelihood.init_params(key)}, ) - parameters["latent"] = jnp.zeros( - shape=(self.likelihood.num_datapoints, 1) - ) + parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1)) return parameters def predict( @@ -793,9 +788,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) covariance += identity(n_test) * jitter - return GaussianDistribution( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn From e8062dac83f2a637863119da6320367ecd26e19f Mon Sep 17 00:00:00 2001 From: henrymoss Date: Wed, 29 Mar 2023 16:48:38 +0100 Subject: [PATCH 3/4] added decoupled --- examples/yacht.pct.py | 4 +- gpjax/gps.py | 146 ++++++++++++++++++++++++++++++++++----- gpjax/types.py | 9 +++ tests/test_gps.py | 108 +++++++++++++++++++++++------ tests/test_parameters.py | 1 - 5 files changed, 223 insertions(+), 45 deletions(-) diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index a56fd2b0..2fd9abc4 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -196,9 +196,7 @@ ax[1].scatter(predictive_mean.squeeze(), residuals) ax[1].plot([0, 1], [0.5, 0.5], color="tab:orange", transform=ax[1].transAxes) ax[1].set_ylim([-1.0, 1.0]) -ax[1].set( - xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals" -) +ax[1].set(xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals") ax[2].hist(np.asarray(residuals), bins=30) ax[2].set_title("Residuals") diff --git a/gpjax/gps.py b/gpjax/gps.py index abf24162..3eef708a 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -34,6 +34,7 @@ from jaxutils import Dataset from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution +from .types import FunctionalSample import deprecation @@ -103,13 +104,6 @@ def _initialise_params(self, key: KeyArray) -> Dict: return self.init_params(key) -FunctionalSample = Callable[[Float[Array, "N D"]], Float[Array, "N B"]] -""" Type alias for functions representing `B` samples from a model, to be evaluated on any set of -`N` inputs (of dimension `D`) and returning the evaluations of each (potentially approximate) -sample draw across these inputs. -""" - - ####################### # GP Priors ####################### @@ -258,7 +252,7 @@ def sample_approx( self, num_samples: int, params: Dict, - seed: KeyArray, + key: KeyArray, num_features: Optional[int] = 100, ) -> FunctionalSample: """Build an approximate sample from the Gaussian process prior. This method @@ -285,22 +279,21 @@ def sample_approx( and then evaluate them over the interval :math:`[0, 1]`: Example: + For a ``prior`` distribution, the following code snippet will + build and evaluate an approximate sample. + >>> import gpjax as gpx >>> import jax.numpy as jnp >>> - >>> kernel = gpx.kernels.RBF() - >>> prior = gpx.Prior(kernel = kernel) - >>> seed = jr.PRNGKey(123) - >>> >>> parameter_state = gpx.initialise(prior) - >>> sample_fn = prior.sample_appox(10, parameter_state.params, seed) + >>> sample_fn = prior.sample_appox(10, parameter_state.params, key) >>> sample_fn(jnp.linspace(0, 1, 100)) Args: num_samples (int): The desired number of samples. params (Dict): The specific set of parameters for which the sample should be generated for. - seed (KeyArray): The random seed used for the sample(s). + key (KeyArray): The random seed used for the sample(s). num_features (int): The number of features used when approximating the kernel. @@ -309,14 +302,15 @@ def sample_approx( FunctionalSample: A function representing an approximate sample from the Gaussian process prior. """ - for integer_input in [num_features, num_samples]: - if (not isinstance(integer_input, int)) or integer_input < 0: - raise ValueError + if (not isinstance(num_features, int)) or num_features <= 0: + raise ValueError(f"num_features must be a positive integer") + if (not isinstance(num_samples, int)) or num_samples <= 0: + raise ValueError(f"num_samples must be a positive integer") approximate_kernel = RFF(self.kernel, num_features) - approximate_kernel_params = approximate_kernel.init_params(seed) + approximate_kernel_params = approximate_kernel.init_params(key) feature_weights = jax.random.normal( - seed, [num_samples, 2 * num_features] + key, [num_samples, 2 * num_features] ) # [B, L] def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: @@ -355,6 +349,8 @@ def init_params(self, key: KeyArray) -> Dict: ####################### # GP Posteriors ####################### + + class AbstractPosterior(AbstractPrior): """The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class.""" @@ -560,6 +556,118 @@ def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict + def sample_approx( + self, + num_samples: int, + params: Dict, + train_data: Dataset, + key: KeyArray, + num_features: Optional[int] = 100, + ) -> FunctionalSample: + """Build an approximate sample from the Gaussian process posterior. This method + provides a function that returns the evaluations of a sample across any given + inputs. + + Unlike when building approximate samples from a Gaussian process prior, decompositions + based on Fourier features alone rarely give accurate samples. Therefore, we must also + include an additional set of features (known as canonical features) to better model the + transition from Gaussian process prior to Gaussian process posterior. + + In particular, we approximate the Gaussian processes' posterior as the finite feature + approximation + + .. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + \sum{j=1}^N v_jk(.,x_j) + + + where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of + the model's kernel and :math:`k(., x_j)` are N canonical features. The Fourier + weights :math:`\theta_i` are samples from a unit Gaussian. See REF for expressions for + the canonical weights :math:`v_j`. + + + A key property of such functional samples is that the same sample draw is + evaluated for all queries. Consistency is a property that is prohibitively costly + to ensure when sampling exactly from the GP prior, as the cost of exact sampling + scales cubically with the size of the sample. In contrast, finite feature representations + can be evaluated with constant cost regardless of the required number of queries. + + Args: + num_samples (int): The desired number of samples. + params (Dict): The specific set of parameters for which the sample + should be generated for. + key (KeyArray): The random seed used for the sample(s). + num_features (int): The number of features used when approximating the + kernel. + + + Returns: + FunctionalSample: A function representing an approximate sample from the Gaussian + process prior. + """ + if (not isinstance(num_features, int)) or num_features <= 0: + raise ValueError(f"num_features must be a positive integer") + if (not isinstance(num_samples, int)) or num_samples <= 0: + raise ValueError(f"num_samples must be a positive integer") + + # Collect required quantities + jitter = get_global_config()["jitter"] + obs_noise = params["likelihood"]["obs_noise"] + + # Approximate kernel with feature decomposition + approximate_kernel = RFF(self.prior.kernel, num_features) + approximate_kernel_params = approximate_kernel.init_params(key) + + def eval_fourier_features( + test_inputs: Float[Array, "N D"] + ) -> Float[Array, "N L"]: + Phi = approximate_kernel.compute_engine.compute_features( # [N, L] + test_inputs, + frequencies=approximate_kernel_params["frequencies"], + scaling_factor=approximate_kernel_params["lengthscale"], + ) + Phi *= jnp.sqrt(params["kernel"]["variance"] / num_features) + return Phi + + # sample weights for Fourier features + fourier_weights = jax.random.normal( + key, [num_samples, 2 * num_features] + ) # [B, L] + + # sample weights v for canonical features + # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Iσ² and ε ᯈ N(0, σ²) + Kxx = self.prior.kernel.gram(params["kernel"], train_data.X) # [N, N] + Sigma = Kxx + identity(train_data.n) * (obs_noise + jitter) # [N, N] + eps = jnp.sqrt(obs_noise) * jax.random.normal( + key, [train_data.n, num_samples] + ) # [N, B] + y = train_data.y - self.prior.mean_function( + params["mean_function"], train_data.X + ) # account for mean + Phi = eval_fourier_features(train_data.X) + canonical_weights = Sigma.solve( + y + eps - jnp.inner(Phi, fourier_weights) + ) # [N, B] + + def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: + fourier_features = eval_fourier_features(test_inputs) + weight_space_contribution = jnp.inner( + fourier_features, fourier_weights + ) # [n, B] + canonical_features = self.prior.kernel.cross_covariance( + params["kernel"], test_inputs, train_data.X + ) # [n, N] + function_space_contribution = jnp.matmul( + canonical_features, canonical_weights + ) + + return ( + self.prior.mean_function(params["mean_function"], test_inputs) + + weight_space_contribution + + function_space_contribution + ) + + return sample_fn + def marginal_log_likelihood( self, train_data: Dataset, diff --git a/gpjax/types.py b/gpjax/types.py index d1e8b110..d8fddd13 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -15,6 +15,8 @@ import jaxutils import deprecation +from typing import Callable +from jaxtyping import Array, Float Dataset = deprecation.deprecated( deprecated_in="0.5.5", @@ -30,3 +32,10 @@ __all__ = ["Dataset" "verify_dataset"] + + +FunctionalSample = Callable[[Float[Array, "N D"]], Float[Array, "N B"]] +""" Type alias for functions representing `B` samples from a model, to be evaluated on any set of +`N` inputs (of dimension `D`) and returning the evaluations of each (potentially approximate) +sample draw across these inputs. +""" diff --git a/tests/test_gps.py b/tests/test_gps.py index 2a66a906..6e4dc4ab 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -68,48 +68,52 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): key = jr.PRNGKey(123) parameter_state = initialise(p, key) params, _, _ = parameter_state.unpack() - params["kernel"]["lengthscale"]=5.0 - params["kernel"]["variance"]=0.1 - + params["kernel"]["lengthscale"] = 5.0 + params["kernel"]["variance"] = 0.1 + + with pytest.raises(ValueError): + p.sample_approx(-1, params, key) + with pytest.raises(ValueError): + p.sample_approx(0, params, key) with pytest.raises(ValueError): - p.sample_approx(-1,params, key) + p.sample_approx(0.5, params, key) with pytest.raises(ValueError): - p.sample_approx(0.5,params, key) + p.sample_approx(1, params, key, -10) with pytest.raises(ValueError): - p.sample_approx(1,params, key, -10) + p.sample_approx(1, params, key, 0) with pytest.raises(ValueError): - p.sample_approx(1,params, key, 0.5) + p.sample_approx(1, params, key, 0.5) - sampled_fn = p.sample_approx(1,params, key, 100) - assert isinstance(sampled_fn, tp.Callable) # check type + sampled_fn = p.sample_approx(1, params, key, 100) + assert isinstance(sampled_fn, tp.Callable) # check type x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) evals = sampled_fn(x) - assert evals.shape == (num_datapoints, 1.0) # check shape + assert evals.shape == (num_datapoints, 1.0) # check shape - sampled_fn_2 = p.sample_approx(1,params, key, 100) + sampled_fn_2 = p.sample_approx(1, params, key, 100) evals_2 = sampled_fn_2(x) max_delta = jnp.max(jnp.abs(evals - evals_2)) - assert max_delta == 0.0 # samples same for same seed + assert max_delta == 0.0 # samples same for same seed new_key = jr.PRNGKey(12345) - sampled_fn_3 = p.sample_approx(1,params, new_key, 100) + sampled_fn_3 = p.sample_approx(1, params, new_key, 100) evals_3 = sampled_fn_3(x) max_delta = jnp.max(jnp.abs(evals - evals_3)) - assert max_delta > 0.1 # samples different for different seed + assert max_delta > 0.1 # samples different for different seed # Check validty of samples using Monte-Carlo - sampled_fn = p.sample_approx(10_000,params, key, 100) + sampled_fn = p.sample_approx(10_000, params, key, 100) sampled_evals = sampled_fn(x) - approx_mean = jnp.mean(sampled_evals, -1) + approx_mean = jnp.mean(sampled_evals, -1) approx_var = jnp.var(sampled_evals, -1) true_predictive = p(params)(x) true_mean = true_predictive.mean() true_var = jnp.diagonal(true_predictive.covariance()) max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) - assert max_error_in_mean < 0.02 # check that samples are correct - assert max_error_in_var < 0.05 # check that samples are correct + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct @pytest.mark.parametrize("num_datapoints", [1, 2, 10]) @@ -156,6 +160,70 @@ def test_conjugate_posterior(num_datapoints): assert sigma.shape == (num_datapoints, num_datapoints) +@pytest.mark.parametrize("num_datapoints", [1, 5]) +@pytest.mark.parametrize("kernel", [RBF(), Matern52()]) +@pytest.mark.parametrize("mean_function", [Zero(), Constant()]) +def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function): + p = Prior(kernel=kernel, mean_function=mean_function) * Gaussian( + num_datapoints=num_datapoints + ) + key = jr.PRNGKey(123) + x = jnp.sort( + jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)), + axis=0, + ) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + parameter_state = initialise(p, key) + params, _, _ = parameter_state.unpack() + params["kernel"]["lengthscale"] = 5.0 + params["kernel"]["variance"] = 0.1 + + with pytest.raises(ValueError): + p.sample_approx(-1, params, D, key) + with pytest.raises(ValueError): + p.sample_approx(0, params, D, key) + with pytest.raises(ValueError): + p.sample_approx(0.5, params, D, key) + with pytest.raises(ValueError): + p.sample_approx(1, params, D, key, -10) + with pytest.raises(ValueError): + p.sample_approx(1, params, D, key, 0) + with pytest.raises(ValueError): + p.sample_approx(1, params, D, key, 0.5) + + sampled_fn = p.sample_approx(1, params, D, key, 100) + assert isinstance(sampled_fn, tp.Callable) # check type + + x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) + evals = sampled_fn(x) + assert evals.shape == (num_datapoints, 1.0) # check shape + + sampled_fn_2 = p.sample_approx(1, params, D, key, 100) + evals_2 = sampled_fn_2(x) + max_delta = jnp.max(jnp.abs(evals - evals_2)) + assert max_delta == 0.0 # samples same for same seed + + new_key = jr.PRNGKey(12345) + sampled_fn_3 = p.sample_approx(1, params, D, new_key, 100) + evals_3 = sampled_fn_3(x) + max_delta = jnp.max(jnp.abs(evals - evals_3)) + assert max_delta > 0.01 # samples different for different seed + + # Check validty of samples using Monte-Carlo + sampled_fn = p.sample_approx(10_000, params, D, key, 100) + sampled_evals = sampled_fn(x) + approx_mean = jnp.mean(sampled_evals, -1) + approx_var = jnp.var(sampled_evals, -1) + true_predictive = p(params, D)(x) + true_mean = true_predictive.mean() + true_var = jnp.diagonal(true_predictive.covariance()) + max_error_in_mean = jnp.max(jnp.abs(approx_mean - true_mean)) + max_error_in_var = jnp.max(jnp.abs(approx_var - true_var)) + assert max_error_in_mean < 0.02 # check that samples are correct + assert max_error_in_var < 0.05 # check that samples are correct + + @pytest.mark.parametrize("num_datapoints", [1, 2, 10]) @pytest.mark.parametrize("likel", NonConjugateLikelihoods) def test_nonconjugate_posterior(num_datapoints, likel): @@ -262,7 +330,3 @@ def test_initialisation_override(kernel): with pytest.raises(ValueError): parameter_state = initialise(p, key, keernel=override_params) - - - - diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 94ed622f..793a9fe7 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -26,7 +26,6 @@ from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ( build_bijectors, - build_trainables, constrain, copy_dict_structure, evaluate_priors, From c6b1d745e42f68c68315e488ee6482b6b8debce2 Mon Sep 17 00:00:00 2001 From: henrymoss Date: Wed, 29 Mar 2023 17:16:24 +0100 Subject: [PATCH 4/4] 4majortom --- examples/regression.pct.py | 8 +++----- gpjax/gps.py | 10 ++++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 5d0090c7..33e25686 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -6,11 +6,11 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.2 +# jupytext_version: 1.6.0 # kernelspec: -# display_name: base +# display_name: gpjax # language: python -# name: python3 +# name: gpjax # --- # %% [markdown] @@ -243,8 +243,6 @@ xtest, ytest, label="Latent function", color="black", linestyle="--", linewidth=1 ) -ax.legend() - # %% [markdown] # ## System configuration diff --git a/gpjax/gps.py b/gpjax/gps.py index 3eef708a..d3aa87ac 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -517,7 +517,7 @@ def predict( μx = mean_function(params["mean_function"], x) # Precompute Gram matrix, Kxx, at training inputs, x - Kxx = kernel.gram(params["kernel"], x) + Kxx = kernel.gram(params["kerrspb nel"], x) Kxx += identity(n) * jitter # Σ = Kxx + Iσ² @@ -571,7 +571,8 @@ def sample_approx( Unlike when building approximate samples from a Gaussian process prior, decompositions based on Fourier features alone rarely give accurate samples. Therefore, we must also include an additional set of features (known as canonical features) to better model the - transition from Gaussian process prior to Gaussian process posterior. + transition from Gaussian process prior to Gaussian process posterior. For more details + see https://arxiv.org/pdf/2002.09309.pdf In particular, we approximate the Gaussian processes' posterior as the finite feature approximation @@ -581,8 +582,9 @@ def sample_approx( where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of the model's kernel and :math:`k(., x_j)` are N canonical features. The Fourier - weights :math:`\theta_i` are samples from a unit Gaussian. See REF for expressions for - the canonical weights :math:`v_j`. + weights :math:`\theta_i` are samples from a unit Gaussian. + See https://arxiv.org/pdf/2002.09309.pdf for expressions for the canonical + weights :math:`v_j`. A key property of such functional samples is that the same sample draw is