Rff sampler #195

Merged: 4 commits, Apr 10, 2023
8 changes: 3 additions & 5 deletions examples/regression.pct.py
@@ -6,11 +6,11 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# jupytext_version: 1.6.0
# kernelspec:
# display_name: base
# display_name: gpjax
# language: python
# name: python3
# name: gpjax
# ---

# %% [markdown]
@@ -243,8 +243,6 @@
xtest, ytest, label="Latent function", color="black", linestyle="--", linewidth=1
)

ax.legend()

# %% [markdown]
# ## System configuration

4 changes: 1 addition & 3 deletions examples/yacht.pct.py
@@ -196,9 +196,7 @@
ax[1].scatter(predictive_mean.squeeze(), residuals)
ax[1].plot([0, 1], [0.5, 0.5], color="tab:orange", transform=ax[1].transAxes)
ax[1].set_ylim([-1.0, 1.0])
ax[1].set(
xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals"
)
ax[1].set(xlabel="Predicted", ylabel="Residuals", title="Predicted vs Residuals")

ax[2].hist(np.asarray(residuals), bins=30)
ax[2].set_title("Residuals")
224 changes: 208 additions & 16 deletions gpjax/gps.py
@@ -17,11 +17,13 @@
from typing import Any, Callable, Dict, Optional

import distrax as dx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float
from jax.random import KeyArray

from jaxlinop import identity
from jaxkern import RFF
from jaxkern.base import AbstractKernel
from jaxutils import PyTree

@@ -32,6 +34,7 @@
from jaxutils import Dataset
from .utils import concat_dictionaries
from .gaussian_distribution import GaussianDistribution
from .types import FunctionalSample

import deprecation

@@ -231,9 +234,7 @@ def predict(
mean_function = self.mean_function
kernel = self.kernel

def predict_fn(
test_inputs: Float[Array, "N D"]
) -> GaussianDistribution:
def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution:

# Unpack test inputs
t = test_inputs
@@ -247,6 +248,89 @@ def predict_fn(

return predict_fn

def sample_approx(
self,
num_samples: int,
params: Dict,
key: KeyArray,
num_features: Optional[int] = 100,
) -> FunctionalSample:
"""Build an approximate sample from the Gaussian process prior. This method
Review comment (Collaborator): Think you'll need a raw string or double backslashes.
Reply (Contributor Author): where?
Reply (Collaborator): """Build should be r"""Build

provides a function that returns the evaluations of a sample at any given set of
inputs.

In particular, we approximate the Gaussian process prior with the finite-feature
approximation

.. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i


where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of
the model's kernel and :math:`\theta_i` are samples from a unit Gaussian.


A key property of such functional samples is that the same sample draw is
evaluated for all queries. Consistency is prohibitively costly to ensure when
sampling exactly from the GP prior, as the cost of exact sampling scales cubically
with the number of query points. In contrast, finite feature representations can be
evaluated with constant cost per query, regardless of how many queries are made.

In the following example, we build 10 such samples
and then evaluate them over the interval :math:`[0, 1]`:

Example:
For a ``prior`` distribution, the following code snippet will
build and evaluate an approximate sample.

>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> import jax.random as jr
>>>
>>> key = jr.PRNGKey(123)
>>> parameter_state = gpx.initialise(prior)
>>> sample_fn = prior.sample_approx(10, parameter_state.params, key)
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))

Args:
num_samples (int): The desired number of samples.
params (Dict): The specific set of parameters for which the sample
should be generated.
key (KeyArray): The random seed used for the sample(s).
num_features (int): The number of features used when approximating the
kernel.


Returns:
FunctionalSample: A function representing an approximate sample from the Gaussian
process prior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")
if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

approximate_kernel = RFF(self.kernel, num_features)
approximate_kernel_params = approximate_kernel.init_params(key)
feature_weights = jax.random.normal(
key, [num_samples, 2 * num_features]
) # [B, L]

def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:

feature_evals = (
approximate_kernel.compute_engine.compute_features( # [N, L]
test_inputs,
frequencies=approximate_kernel_params["frequencies"],
scaling_factor=approximate_kernel_params["lengthscale"],
)
)
feature_evals *= jnp.sqrt(params["kernel"]["variance"] / num_features)
evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B]
return (
self.mean_function(params["mean_function"], test_inputs)
+ evaluated_sample
)

return sample_fn
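
As an aside for readers new to random Fourier features, the construction described in the docstring above can be sketched in a few lines of standalone JAX. This is purely illustrative and not GPJax API: it assumes an RBF kernel with unit lengthscale and unit variance, and all variable names are made up for the example.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_frequencies = 100                                  # m half-frequencies
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)         # [N, 1] query points

# For an RBF kernel with unit lengthscale, frequencies are standard Gaussian draws.
omega = jax.random.normal(key, [num_frequencies, 1])   # [m, D]
proj = x @ omega.T                                     # [N, m]
phi = jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)
phi = phi / jnp.sqrt(num_frequencies)                  # [N, 2m] feature matrix

# phi @ phi.T is a Monte Carlo estimate of the exact Gram matrix k(x, x').
approx_gram = phi @ phi.T
exact_gram = jnp.exp(-0.5 * (x - x.T) ** 2)

# A weight vector theta ~ N(0, I) turns the features into one approximate
# prior sample, evaluated consistently at every query point.
theta = jax.random.normal(jax.random.PRNGKey(1), [2 * num_frequencies])
approx_sample = phi @ theta                            # [N]
```

The number of features trades accuracy for cost: more frequencies tighten the Monte Carlo estimate of the kernel, while evaluation remains linear in the number of query points.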

def init_params(self, key: KeyArray) -> Dict:
"""Initialise the GP prior's parameter set.

@@ -265,6 +349,8 @@ def init_params(self, key: KeyArray) -> Dict:
#######################
# GP Posteriors
#######################


class AbstractPosterior(AbstractPrior):
"""The base GP posterior object conditioned on an observed dataset. All
posterior objects should inherit from this class."""
@@ -431,7 +517,7 @@ def predict(
μx = mean_function(params["mean_function"], x)

# Precompute Gram matrix, Kxx, at training inputs, x
Kxx = kernel.gram(params["kernel"], x)
Kxx = kernel.gram(params["kerrspb nel"], x)
Kxx += identity(n) * jitter

# Σ = Kxx + Iσ²
@@ -466,12 +552,124 @@ def predict(test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
covariance += identity(n_test) * jitter

return GaussianDistribution(
jnp.atleast_1d(mean.squeeze()), covariance
)
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)

return predict

def sample_approx(
self,
num_samples: int,
params: Dict,
train_data: Dataset,
key: KeyArray,
num_features: Optional[int] = 100,
Review comment (Collaborator): Is 100 sensible here?
Reply (Contributor Author): yeah actually, you don't need very many for the decoupled sampling!

) -> FunctionalSample:
"""Build an approximate sample from the Gaussian process posterior. This method
provides a function that returns the evaluations of a sample at any given set of
inputs.

Unlike when building approximate samples from a Gaussian process prior, decompositions
based on Fourier features alone rarely give accurate samples. Therefore, we must also
include an additional set of features (known as canonical features) to better model the
transition from Gaussian process prior to Gaussian process posterior. For more details
see https://arxiv.org/pdf/2002.09309.pdf

In particular, we approximate the Gaussian process posterior with the finite-feature
approximation

.. math:: \hat{f}(x) = \sum_{i=1}^m \phi_i(x)\theta_i + \sum_{j=1}^N v_j k(\cdot, x_j)


where :math:`\phi_i` are m features sampled from the Fourier feature decomposition of
the model's kernel and :math:`k(\cdot, x_j)` are N canonical features. The Fourier
weights :math:`\theta_i` are samples from a unit Gaussian.
See https://arxiv.org/pdf/2002.09309.pdf for expressions for the canonical
weights :math:`v_j`.


A key property of such functional samples is that the same sample draw is
evaluated for all queries. Consistency is prohibitively costly to ensure when
sampling exactly from the GP posterior, as the cost of exact sampling scales
cubically with the number of query points. In contrast, finite feature
representations can be evaluated with constant cost per query, regardless of how
many queries are made.

Args:
num_samples (int): The desired number of samples.
params (Dict): The specific set of parameters for which the sample
should be generated.
key (KeyArray): The random seed used for the sample(s).
num_features (int): The number of features used when approximating the
kernel.


Returns:
FunctionalSample: A function representing an approximate sample from the Gaussian
process posterior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")
if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

# Collect required quantities
jitter = get_global_config()["jitter"]
obs_noise = params["likelihood"]["obs_noise"]

# Approximate kernel with feature decomposition
approximate_kernel = RFF(self.prior.kernel, num_features)
approximate_kernel_params = approximate_kernel.init_params(key)

def eval_fourier_features(
test_inputs: Float[Array, "N D"]
) -> Float[Array, "N L"]:
Phi = approximate_kernel.compute_engine.compute_features( # [N, L]
test_inputs,
frequencies=approximate_kernel_params["frequencies"],
scaling_factor=approximate_kernel_params["lengthscale"],
)
Phi *= jnp.sqrt(params["kernel"]["variance"] / num_features)
return Phi

# sample weights for Fourier features
fourier_weights = jax.random.normal(
key, [num_samples, 2 * num_features]
) # [B, L]

# sample weights v for canonical features
# v = Σ⁻¹ (y + ε - Φω) for Σ = Kxx + Iσ² and ε ~ N(0, σ²)
Kxx = self.prior.kernel.gram(params["kernel"], train_data.X) # [N, N]
Sigma = Kxx + identity(train_data.n) * (obs_noise + jitter) # [N, N]
eps = jnp.sqrt(obs_noise) * jax.random.normal(
key, [train_data.n, num_samples]
) # [N, B]
y = train_data.y - self.prior.mean_function(
params["mean_function"], train_data.X
) # account for mean
Phi = eval_fourier_features(train_data.X)
canonical_weights = Sigma.solve(
y + eps - jnp.inner(Phi, fourier_weights)
) # [N, B]

def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
fourier_features = eval_fourier_features(test_inputs)
weight_space_contribution = jnp.inner(
fourier_features, fourier_weights
) # [n, B]
canonical_features = self.prior.kernel.cross_covariance(
params["kernel"], test_inputs, train_data.X
) # [n, N]
function_space_contribution = jnp.matmul(
canonical_features, canonical_weights
)

return (
self.prior.mean_function(params["mean_function"], test_inputs)
+ weight_space_contribution
+ function_space_contribution
)

return sample_fn
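
Because the posterior method carries no doctest of its own, here is a minimal usage sketch under stated assumptions: a conjugate `posterior` has already been constructed from a prior and Gaussian likelihood (as in the GPJax regression notebook), `D` is the training `Dataset`, and the names `posterior`, `D`, `xtest`, and `key` are illustrative rather than part of this diff.

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(42)

# Initialise parameters for the posterior model; `posterior` and the training
# dataset `D` are assumed to exist already.
parameter_state = gpx.initialise(posterior, key)

# Build a function representing 5 approximate posterior samples, using 200
# Fourier features for the weight-space part of the decomposition.
sample_fn = posterior.sample_approx(
    num_samples=5,
    params=parameter_state.params,
    train_data=D,
    key=key,
    num_features=200,
)

# The same 5 sample draws can be evaluated consistently at any query inputs.
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
samples = sample_fn(xtest)  # [100, 5]
```

Because `sample_fn` is an ordinary function of the query inputs, it can be differentiated or optimised over, which is the main appeal of decoupled sampling for downstream tasks such as Thompson sampling.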

def marginal_log_likelihood(
self,
train_data: Dataset,
@@ -580,9 +778,7 @@ def mll(
)

return constant * (
marginal_likelihood.log_prob(
jnp.atleast_1d(y.squeeze())
).squeeze()
marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze()
)

return mll
@@ -630,9 +826,7 @@ def init_params(self, key: KeyArray) -> Dict:
self.prior.init_params(key),
{"likelihood": self.likelihood.init_params(key)},
)
parameters["latent"] = jnp.zeros(
shape=(self.likelihood.num_datapoints, 1)
)
parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1))
return parameters

def predict(
@@ -704,9 +898,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
covariance += identity(n_test) * jitter

return GaussianDistribution(
jnp.atleast_1d(mean.squeeze()), covariance
)
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)

return predict_fn

5 changes: 3 additions & 2 deletions gpjax/mean_functions.py
@@ -117,8 +117,9 @@ def init_params(self, key: KeyArray) -> Dict:

class Constant(AbstractMeanFunction):
"""
A zero mean function. This function returns a repeated scalar value for all inputs.
The scalar value itself can be treated as a model hyperparameter and learned during training.
A constant mean function. This function returns a repeated scalar value for all inputs.
The scalar value itself can be treated as a model hyperparameter and learned during training but
defaults to 1.0.
"""

def __init__(
Expand Down
9 changes: 9 additions & 0 deletions gpjax/types.py
@@ -15,6 +15,8 @@

import jaxutils
import deprecation
from typing import Callable
from jaxtyping import Array, Float

Dataset = deprecation.deprecated(
deprecated_in="0.5.5",
@@ -30,3 +32,10 @@


__all__ = ["Dataset" "verify_dataset"]


FunctionalSample = Callable[[Float[Array, "N D"]], Float[Array, "N B"]]
""" Type alias for functions representing `B` samples from a model, to be evaluated on any set of
`N` inputs (of dimension `D`) and returning the evaluations of each (potentially approximate)
sample draw across these inputs.
"""