Rff sampler #195
Merged
4 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,13 @@ | |
from typing import Any, Callable, Dict, Optional | ||
|
||
import distrax as dx | ||
import jax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, Float | ||
from jax.random import KeyArray | ||
|
||
from jaxlinop import identity | ||
from jaxkern import RFF | ||
from jaxkern.base import AbstractKernel | ||
from jaxutils import PyTree | ||
|
||
|
@@ -32,6 +34,7 @@ | |
from jaxutils import Dataset | ||
from .utils import concat_dictionaries | ||
from .gaussian_distribution import GaussianDistribution | ||
from .types import FunctionalSample | ||
|
||
import deprecation | ||
|
||
|
@@ -231,9 +234,7 @@ def predict( | |
mean_function = self.mean_function | ||
kernel = self.kernel | ||
|
||
def predict_fn( | ||
test_inputs: Float[Array, "N D"] | ||
) -> GaussianDistribution: | ||
def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: | ||
|
||
# Unpack test inputs | ||
t = test_inputs | ||
|
@@ -247,6 +248,89 @@ def predict_fn( | |
|
||
return predict_fn | ||
|
||
def sample_approx( | ||
self, | ||
num_samples: int, | ||
params: Dict, | ||
key: KeyArray, | ||
num_features: Optional[int] = 100, | ||
) -> FunctionalSample: | ||
"""Build an approximate sample from the Gaussian process prior. This method | ||
provides a function that returns the evaluations of a sample across any given | ||
inputs. | ||
|
||
In particular, we approximate the Gaussian process prior as the finite feature | ||
approximation | ||
|
||
.. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i | ||
|
||
|
||
where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of | ||
the model's kernel and :math:`\theta_i` are samples from a unit Gaussian. | ||
|
||
|
||
A key property of such functional samples is that the same sample draw is | ||
evaluated for all queries. Consistency is a property that is prohibitively costly | ||
to ensure when sampling exactly from the GP prior, as the cost of exact sampling | ||
scales cubically with the size of the sample. In contrast, finite feature representations | ||
can be evaluated with constant cost regardless of the required number of queries. | ||
|
||
In the following example, we build 10 such samples | ||
and then evaluate them over the interval :math:`[0, 1]`: | ||
|
||
Example: | ||
For a ``prior`` distribution, the following code snippet will | ||
build and evaluate an approximate sample. | ||
|
||
>>> import gpjax as gpx | ||
>>> import jax.numpy as jnp | ||
>>> | ||
>>> parameter_state = gpx.initialise(prior) | ||
>>> sample_fn = prior.sample_approx(10, parameter_state.params, key) | ||
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1)) | ||
|
||
Args: | ||
num_samples (int): The desired number of samples. | ||
params (Dict): The specific set of parameters for which the sample | ||
should be generated. | ||
key (KeyArray): The random seed used for the sample(s). | ||
num_features (int): The number of features used when approximating the | ||
kernel. | ||
|
||
|
||
Returns: | ||
FunctionalSample: A function representing an approximate sample from the Gaussian | ||
process prior. | ||
""" | ||
if (not isinstance(num_features, int)) or num_features <= 0: | ||
raise ValueError(f"num_features must be a positive integer") | ||
if (not isinstance(num_samples, int)) or num_samples <= 0: | ||
raise ValueError(f"num_samples must be a positive integer") | ||
|
||
approximate_kernel = RFF(self.kernel, num_features) | ||
approximate_kernel_params = approximate_kernel.init_params(key) | ||
feature_weights = jax.random.normal( | ||
key, [num_samples, 2 * num_features] | ||
) # [B, L] | ||
|
||
def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]: | ||
|
||
feature_evals = ( | ||
approximate_kernel.compute_engine.compute_features( # [N, L] | ||
test_inputs, | ||
frequencies=approximate_kernel_params["frequencies"], | ||
scaling_factor=approximate_kernel_params["lengthscale"], | ||
) | ||
) | ||
feature_evals *= jnp.sqrt(params["kernel"]["variance"] / num_features) | ||
evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B] | ||
return ( | ||
self.mean_function(params["mean_function"], test_inputs) | ||
+ evaluated_sample | ||
) | ||
|
||
return sample_fn | ||
|
||
def init_params(self, key: KeyArray) -> Dict: | ||
"""Initialise the GP prior's parameter set. | ||
|
||
|
@@ -265,6 +349,8 @@ def init_params(self, key: KeyArray) -> Dict: | |
####################### | ||
# GP Posteriors | ||
####################### | ||
|
||
|
||
class AbstractPosterior(AbstractPrior): | ||
"""The base GP posterior object conditioned on an observed dataset. All | ||
posterior objects should inherit from this class.""" | ||
|
@@ -431,7 +517,7 @@ def predict( | |
μx = mean_function(params["mean_function"], x) | ||
|
||
# Precompute Gram matrix, Kxx, at training inputs, x | ||
Kxx = kernel.gram(params["kernel"], x) | ||
Kxx = kernel.gram(params["kernel"], x) | ||
Kxx += identity(n) * jitter | ||
|
||
# Σ = Kxx + Iσ² | ||
|
@@ -466,12 +552,124 @@ def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: | |
covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt) | ||
covariance += identity(n_test) * jitter | ||
|
||
return GaussianDistribution( | ||
jnp.atleast_1d(mean.squeeze()), covariance | ||
) | ||
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) | ||
|
||
return predict | ||
|
||
def sample_approx( | ||
self, | ||
num_samples: int, | ||
params: Dict, | ||
train_data: Dataset, | ||
key: KeyArray, | ||
num_features: Optional[int] = 100, | ||
Review comment: Is 100 sensible here?
Reply: Yeah, actually, you don't need very many for the decoupled sampling!
||
) -> FunctionalSample: | ||
"""Build an approximate sample from the Gaussian process posterior. This method | ||
henrymoss marked this conversation as resolved.
|
||
provides a function that returns the evaluations of a sample across any given | ||
inputs. | ||
|
||
Unlike when building approximate samples from a Gaussian process prior, decompositions | ||
based on Fourier features alone rarely give accurate samples. Therefore, we must also | ||
include an additional set of features (known as canonical features) to better model the | ||
transition from Gaussian process prior to Gaussian process posterior. For more details | ||
see https://arxiv.org/pdf/2002.09309.pdf | ||
|
||
In particular, we approximate the Gaussian process posterior as the finite feature | ||
approximation | ||
|
||
.. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + \sum_{j=1}^N v_j k(., x_j) | ||
|
||
|
||
where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of | ||
the model's kernel and :math:`k(., x_j)` are N canonical features. The Fourier | ||
weights :math:`\theta_i` are samples from a unit Gaussian. | ||
See https://arxiv.org/pdf/2002.09309.pdf for expressions for the canonical | ||
weights :math:`v_j`. | ||
|
||
|
||
A key property of such functional samples is that the same sample draw is | ||
evaluated for all queries. Consistency is a property that is prohibitively costly | ||
to ensure when sampling exactly from the GP posterior, as the cost of exact sampling | ||
scales cubically with the size of the sample. In contrast, finite feature representations | ||
can be evaluated with constant cost regardless of the required number of queries. | ||
|
||
Args: | ||
num_samples (int): The desired number of samples. | ||
params (Dict): The specific set of parameters for which the sample | ||
should be generated. | ||
train_data (Dataset): The training data on which the posterior is | ||
conditioned. | ||
key (KeyArray): The random seed used for the sample(s). | ||
num_features (int): The number of features used when approximating the | ||
kernel. | ||
|
||
|
||
Returns: | ||
FunctionalSample: A function representing an approximate sample from the Gaussian | ||
process posterior. | ||
""" | ||
if (not isinstance(num_features, int)) or num_features <= 0: | ||
raise ValueError(f"num_features must be a positive integer") | ||
if (not isinstance(num_samples, int)) or num_samples <= 0: | ||
raise ValueError(f"num_samples must be a positive integer") | ||
|
||
# Collect required quantities | ||
jitter = get_global_config()["jitter"] | ||
obs_noise = params["likelihood"]["obs_noise"] | ||
|
||
# Approximate kernel with feature decomposition | ||
approximate_kernel = RFF(self.prior.kernel, num_features) | ||
approximate_kernel_params = approximate_kernel.init_params(key) | ||
|
||
def eval_fourier_features( | ||
test_inputs: Float[Array, "N D"] | ||
) -> Float[Array, "N L"]: | ||
Phi = approximate_kernel.compute_engine.compute_features( # [N, L] | ||
test_inputs, | ||
frequencies=approximate_kernel_params["frequencies"], | ||
scaling_factor=approximate_kernel_params["lengthscale"], | ||
) | ||
Phi *= jnp.sqrt(params["kernel"]["variance"] / num_features) | ||
return Phi | ||
|
||
# sample weights for Fourier features | ||
fourier_weights = jax.random.normal( | ||
key, [num_samples, 2 * num_features] | ||
) # [B, L] | ||
|
||
# sample weights v for canonical features | ||
# v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Iσ² and ε ~ N(0, σ²) | ||
Kxx = self.prior.kernel.gram(params["kernel"], train_data.X) # [N, N] | ||
Sigma = Kxx + identity(train_data.n) * (obs_noise + jitter) # [N, N] | ||
eps = jnp.sqrt(obs_noise) * jax.random.normal( | ||
key, [train_data.n, num_samples] | ||
) # [N, B] | ||
y = train_data.y - self.prior.mean_function( | ||
params["mean_function"], train_data.X | ||
) # account for mean | ||
Phi = eval_fourier_features(train_data.X) | ||
canonical_weights = Sigma.solve( | ||
y + eps - jnp.inner(Phi, fourier_weights) | ||
) # [N, B] | ||
|
||
def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: | ||
fourier_features = eval_fourier_features(test_inputs) | ||
weight_space_contribution = jnp.inner( | ||
fourier_features, fourier_weights | ||
) # [n, B] | ||
canonical_features = self.prior.kernel.cross_covariance( | ||
params["kernel"], test_inputs, train_data.X | ||
) # [n, N] | ||
function_space_contribution = jnp.matmul( | ||
canonical_features, canonical_weights | ||
) | ||
|
||
return ( | ||
self.prior.mean_function(params["mean_function"], test_inputs) | ||
+ weight_space_contribution | ||
+ function_space_contribution | ||
) | ||
|
||
return sample_fn | ||
|
||
def marginal_log_likelihood( | ||
self, | ||
train_data: Dataset, | ||
|
@@ -580,9 +778,7 @@ def mll( | |
) | ||
|
||
return constant * ( | ||
marginal_likelihood.log_prob( | ||
jnp.atleast_1d(y.squeeze()) | ||
).squeeze() | ||
marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze() | ||
) | ||
|
||
return mll | ||
|
@@ -630,9 +826,7 @@ def init_params(self, key: KeyArray) -> Dict: | |
self.prior.init_params(key), | ||
{"likelihood": self.likelihood.init_params(key)}, | ||
) | ||
parameters["latent"] = jnp.zeros( | ||
shape=(self.likelihood.num_datapoints, 1) | ||
) | ||
parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1)) | ||
return parameters | ||
|
||
def predict( | ||
|
@@ -704,9 +898,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: | |
covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) | ||
covariance += identity(n_test) * jitter | ||
|
||
return GaussianDistribution( | ||
jnp.atleast_1d(mean.squeeze()), covariance | ||
) | ||
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance) | ||
|
||
return predict_fn | ||
|
||
|
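The two `sample_approx` methods in this diff can be illustrated outside of GPJax with a small standalone sketch. It is not the PR's implementation: it assumes a squared-exponential kernel, a zero mean function, and plain `jax.numpy` arrays, and every helper name (`rbf_kernel`, `fourier_features`, `sample_prior_approx`, `sample_posterior_approx`) is hypothetical. One deliberate difference from the diff: the sketch splits the PRNG key so that the frequencies, the Fourier weights, and the noise draws each use their own key, whereas the diff passes the same `key` to each call.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """Squared-exponential kernel between rows of x [N, D] and y [M, D]."""
    sq_dists = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / lengthscale**2)


def fourier_features(x, frequencies, lengthscale=1.0, variance=1.0):
    """Cos/sin random Fourier features of shape [N, 2 * num_features]."""
    num_features = frequencies.shape[0]
    z = (x / lengthscale) @ frequencies.T  # [N, num_features]
    phi = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
    return phi * jnp.sqrt(variance / num_features)


def sample_prior_approx(key, num_samples, num_features, dim):
    """Return a function evaluating num_samples approximate prior draws."""
    # Separate keys for the frequencies and the feature weights.
    key_freq, key_w = jax.random.split(key)
    frequencies = jax.random.normal(key_freq, (num_features, dim))
    weights = jax.random.normal(key_w, (num_samples, 2 * num_features))  # [B, L]

    def sample_fn(test_inputs):
        return fourier_features(test_inputs, frequencies) @ weights.T  # [N, B]

    return sample_fn


def sample_posterior_approx(
    key, num_samples, num_features, x_train, y_train, obs_noise=0.1, jitter=1e-6
):
    """Return a function evaluating approximate posterior draws (zero mean)."""
    n, dim = x_train.shape
    key_freq, key_w, key_eps = jax.random.split(key, 3)
    frequencies = jax.random.normal(key_freq, (num_features, dim))
    weights = jax.random.normal(key_w, (num_samples, 2 * num_features))  # [B, L]
    eps = jnp.sqrt(obs_noise) * jax.random.normal(key_eps, (n, num_samples))

    # Canonical weights v = Sigma^{-1} (y + eps - Phi w), Sigma = Kxx + sigma^2 I.
    sigma = rbf_kernel(x_train, x_train) + (obs_noise + jitter) * jnp.eye(n)
    phi_train = fourier_features(x_train, frequencies)  # [n, L]
    v = jnp.linalg.solve(sigma, y_train + eps - phi_train @ weights.T)  # [n, B]

    def sample_fn(test_inputs):
        weight_space = fourier_features(test_inputs, frequencies) @ weights.T
        function_space = rbf_kernel(test_inputs, x_train) @ v
        return weight_space + function_space  # [N, B]

    return sample_fn


# Usage: draw 5 posterior samples and evaluate them on a grid.
x_obs = jnp.array([[0.0], [1.0], [2.0]])
y_obs = jnp.array([[0.5], [-0.3], [1.2]])
posterior_fn = sample_posterior_approx(jax.random.PRNGKey(0), 5, 50, x_obs, y_obs)
grid_values = posterior_fn(jnp.linspace(0.0, 3.0, 50).reshape(-1, 1))
```

Because the weights and frequencies are fixed when the closure is built, repeated calls to the returned function evaluate the same draws, which is the consistency property the docstrings describe.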
Think you'll need a raw string or double backslashes.
where?
"""Build should be r"""Build
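The reviewer's point about raw strings can be checked with a short sketch. In a regular docstring Python interprets `\t` as a tab, so `\theta_i` loses its backslash and its first letter, while the `r` prefix keeps the LaTeX intact. The two function names below are hypothetical and exist only for the comparison.

```python
def plain_doc():
    """Weights \theta_i are standard normal."""


def raw_doc():
    r"""Weights \theta_i are standard normal."""


# The non-raw docstring contains a tab where "\t" was written.
has_tab = "\t" in plain_doc.__doc__
# The raw docstring keeps the literal backslash followed by "theta_i".
keeps_latex = "\\theta_i" in raw_doc.__doc__
```

Sequences such as `\h` in `\hat` or `\s` in `\sum` are not recognised escapes, so they survive in a non-raw string, but recent Python versions warn about them; a raw docstring avoids both problems.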