Skip to content

Commit

Permalink
wip (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss authored Apr 30, 2023
1 parent 2aba697 commit 5381567
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 44 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ repos:
rev: 23.3.0
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
7 changes: 6 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,14 @@ def find_version(*file_paths):
copyright = "2021, Thomas Pinder"
author = "Thomas Pinder"

from os.path import (
dirname,
join,
pardir,
)

# The full version, including alpha/beta/rc tags
import sys
from os.path import dirname, join, pardir

sys.path.insert(0, join(dirname(__file__), pardir))

Expand Down
8 changes: 7 additions & 1 deletion docs/conf_sphinx_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# This file is credited to the Flax authors.

from typing import Any, Dict, List, Set, Tuple
from typing import (
Any,
Dict,
List,
Set,
Tuple,
)

import sphinx.ext.autodoc
import sphinx.ext.autosummary.generate as ag
Expand Down
2 changes: 1 addition & 1 deletion gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from gpjax.likelihoods import (
Bernoulli,
Gaussian,
Poisson
Poisson,
)
from gpjax.mean_functions import (
Constant,
Expand Down
67 changes: 45 additions & 22 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from beartype.typing import (
Any,
Callable,
Dict,
Optional,
)
import jax.numpy as jnp
Expand Down Expand Up @@ -269,17 +268,18 @@ def sample_approx(
FunctionalSample: A function representing an approximate sample from the Gaussian
process prior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")

if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

approximate_kernel = RFF(base_kernel=self.kernel, num_basis_fns=num_features)
# sample fourier features
fourier_feature_fn = _build_fourier_features_fn(self, num_features, key)

# sample fourier weights
feature_weights = normal(key, [num_samples, 2 * num_features]) # [B, L]

def sample_fn(test_inputs: Float[Array, "N D"]) -> Float[Array, "N B"]:
feature_evals = approximate_kernel.compute_features(x=test_inputs)
feature_evals *= jnp.sqrt(self.kernel.variance / num_features)
feature_evals = fourier_feature_fn(test_inputs) # [N, L]
evaluated_sample = jnp.inner(feature_evals, feature_weights) # [N, B]
return self.mean_function(test_inputs) + evaluated_sample

Expand Down Expand Up @@ -501,24 +501,13 @@ def sample_approx(
FunctionalSample: A function representing an approximate sample from the Gaussian
process prior.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")
if (not isinstance(num_samples, int)) or num_samples <= 0:
raise ValueError("num_samples must be a positive integer")

# Approximate kernel with feature decomposition
approximate_kernel = RFF(
base_kernel=self.prior.kernel, num_basis_fns=num_features
)

def eval_fourier_features(
test_inputs: Float[Array, "N D"]
) -> Float[Array, "N L"]:
Phi = approximate_kernel.compute_features(x=test_inputs)
Phi *= jnp.sqrt(self.prior.kernel.variance / num_features)
return Phi
# sample fourier features
fourier_feature_fn = _build_fourier_features_fn(self.prior, num_features, key)

# sample weights for Fourier features
# sample fourier weights
fourier_weights = normal(key, [num_samples, 2 * num_features]) # [B, L]

# sample weights v for canonical features
Expand All @@ -531,13 +520,13 @@ def eval_fourier_features(
key, [train_data.n, num_samples]
) # [N, B]
y = train_data.y - self.prior.mean_function(train_data.X) # account for mean
Phi = eval_fourier_features(train_data.X)
Phi = fourier_feature_fn(train_data.X)
canonical_weights = Sigma.solve(
y + eps - jnp.inner(Phi, fourier_weights)
) # [N, B]

def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
fourier_features = eval_fourier_features(test_inputs)
fourier_features = fourier_feature_fn(test_inputs) # [n, L]
weight_space_contribution = jnp.inner(
fourier_features, fourier_weights
) # [n, B]
Expand Down Expand Up @@ -634,6 +623,9 @@ def predict(
return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)


#######################
# Utils
#######################
def construct_posterior(
prior: Prior, likelihood: AbstractLikelihood
) -> AbstractPosterior:
Expand All @@ -658,6 +650,37 @@ def construct_posterior(
return NonConjugatePosterior(prior=prior, likelihood=likelihood)


def _build_fourier_features_fn(
prior: Prior, num_features: int, key: KeyArray
) -> Callable[[Float[Array, "N D"]], Float[Array, "N L"]]:
"""Return a function that evaluates features sampled from the Fourier feature
decomposition of the prior's kernel.
Args:
prior (Prior): The Prior distribution.
num_features (int): The number of feature functions to be sampled.
key (KeyArray): The random seed used.
Returns
-------
Callable: A callable function evaluation the sampled feature functions.
"""
if (not isinstance(num_features, int)) or num_features <= 0:
raise ValueError("num_features must be a positive integer")

# Approximate kernel with feature decomposition
approximate_kernel = RFF(
base_kernel=prior.kernel, num_basis_fns=num_features, key=key
)

def eval_fourier_features(test_inputs: Float[Array, "N D"]) -> Float[Array, "N L"]:
Phi = approximate_kernel.compute_features(x=test_inputs)
Phi *= jnp.sqrt(prior.kernel.variance / num_features)
return Phi

return eval_fourier_features


__all__ = [
"AbstractPrior",
"Prior",
Expand Down
27 changes: 10 additions & 17 deletions tests/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
ValidationErrors = ValueError

from dataclasses import is_dataclass
import shutil
from typing import Callable

from jax.config import config
Expand All @@ -31,11 +30,6 @@
import pytest
import tensorflow_probability.substrates.jax.distributions as tfd

from gpjax.base import (
load_tree,
save_tree,
)

# from gpjax.dataset import Dataset
from gpjax.dataset import Dataset
from gpjax.gaussian_distribution import GaussianDistribution
Expand All @@ -50,15 +44,13 @@
from gpjax.kernels import (
RBF,
AbstractKernel,
Matern12,
Matern32,
Matern52,
)
from gpjax.likelihoods import (
AbstractLikelihood,
Bernoulli,
Gaussian,
Poisson
Poisson,
)
from gpjax.mean_functions import (
AbstractMeanFunction,
Expand Down Expand Up @@ -272,7 +264,7 @@ def test_posterior_construct(
@pytest.mark.parametrize("kernel", [RBF, Matern52])
@pytest.mark.parametrize("mean_function", [Zero(), Constant()])
def test_prior_sample_approx(num_datapoints, kernel, mean_function):
kern = kernel(lengthscale=5.0, variance=0.1)
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
p = Prior(kernel=kern, mean_function=mean_function)
key = jr.PRNGKey(123)

Expand All @@ -292,7 +284,7 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
sampled_fn = p.sample_approx(1, key, 100)
assert isinstance(sampled_fn, Callable) # check type

x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
evals = sampled_fn(x)
assert evals.shape == (num_datapoints, 1.0) # check shape

Expand Down Expand Up @@ -325,16 +317,17 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function):
@pytest.mark.parametrize("kernel", [RBF, Matern52])
@pytest.mark.parametrize("mean_function", [Zero(), Constant()])
def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function):
kern = kernel(lengthscale=5.0, variance=0.1)
kern = kernel(lengthscale=jnp.array([5.0, 1.0]), variance=0.1)
p = Prior(kernel=kern, mean_function=mean_function) * Gaussian(
num_datapoints=num_datapoints
)
key = jr.PRNGKey(123)
x = jnp.sort(
jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 1)),
axis=0,

x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
y = (
jnp.mean(jnp.sin(x), 1, keepdims=True)
+ jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1
)
y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1
D = Dataset(X=x, y=y)

with pytest.raises(ValueError):
Expand All @@ -353,7 +346,7 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function
sampled_fn = p.sample_approx(1, D, key, 100)
assert isinstance(sampled_fn, Callable) # check type

x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1)
x = jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(num_datapoints, 2))
evals = sampled_fn(x)
assert evals.shape == (num_datapoints, 1.0) # check shape

Expand Down
1 change: 0 additions & 1 deletion tests/test_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ class TestPoisson(BaseTestLikelihood):
def _test_call_check(
likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist
):

# Test call method.
pred_dist = likelihood(latent_dist)

Expand Down

0 comments on commit 5381567

Please sign in to comment.