Abstract integrator #283

Merged: 14 commits, Jun 7, 2023
976 changes: 488 additions & 488 deletions docs/_static/bijector_figure.svg
Binary file added docs/_static/step_size_figure.png
530 changes: 265 additions & 265 deletions docs/_static/step_size_figure.svg
2 changes: 1 addition & 1 deletion docs/contributing.md
@@ -1,4 +1,4 @@
# Contributing

## How can I contribute?

51 changes: 32 additions & 19 deletions docs/scripts/sharp_bits_figure.py
@@ -18,73 +18,86 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as patches
from matplotlib import patches

plt.style.use("../examples/gpjax.mplstyle")
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

# %%
fig, ax = plt.subplots()
ax.axhline(y = 0.25, color=cols[0], linewidth=1.5)
ax.axhline(y=0.25, color=cols[0], linewidth=1.5)

xs = [0.02, 0.06, 0.1, 0.17]
ys = np.ones_like(xs) * 0.25

ax.scatter(xs, ys, color=cols[1], marker="o", s=100, zorder=2)

for idx, x in enumerate(xs):
ax.annotate(text = f'$\ell_{{t-{idx+1}}}$', xy=(x, 0.25), xytext=(x+0.01, 0.275), ha='center', va='bottom')
ax.annotate(
text=f"$\\ell_{{t-{idx+1}}}$",
xy=(x, 0.25),
xytext=(x + 0.01, 0.275),
ha="center",
va="bottom",
)


style = "Simple, tail_width=0.5, head_width=4, head_length=8"
kw = dict(arrowstyle=style, color="k")

for i in range(len(xs)-1):
a = patches.FancyArrowPatch((xs[i+1], 0.25), (xs[i], 0.25), connectionstyle="arc3,rad=-.5", **kw)
for i in range(len(xs) - 1):
a = patches.FancyArrowPatch(
(xs[i + 1], 0.25), (xs[i], 0.25), connectionstyle="arc3,rad=-.5", **kw
)
ax.add_patch(a)


ax.scatter(-0.03, 0.25, color=cols[1], marker="x", s=100, linewidth=5, zorder=2)

a = patches.FancyArrowPatch((xs[0], 0.25), (-0.03, 0.25), connectionstyle="arc3,rad=-.5", **kw)
a = patches.FancyArrowPatch(
(xs[0], 0.25), (-0.03, 0.25), connectionstyle="arc3,rad=-.5", **kw
)
ax.add_patch(a)

ax.axvline(x = 0, color='black', linewidth=0.5, linestyle='-.')
ax.axvline(x=0, color="black", linewidth=0.5, linestyle="-.")
ax.get_yaxis().set_visible(False)
ax.spines["left"].set_visible(False)
ax.set_ylim(0., 0.5)
ax.set_ylim(0.0, 0.5)
ax.set_xlim(-0.07, 0.25)
plt.savefig('../_static/step_size_figure.svg', bbox_inches='tight')
plt.savefig("../_static/step_size_figure.png", bbox_inches="tight")

# %%
import tensorflow_probability.substrates.jax.bijectors as tfb
import jax.numpy as jnp

bij = tfb.Exp()

x = np.linspace(0.05, 3., 6)
x = np.linspace(0.05, 3.0, 6)
y = np.asarray(bij.inverse(x))
lval = 0.5
rval = 0.52

fig, ax = plt.subplots()
ax.scatter(x, np.ones_like(x)*lval, s=100, label='Constrained value')
ax.scatter(y, np.ones_like(y)*rval, marker='o', s=100, label='Unconstrained value')
ax.scatter(x, np.ones_like(x) * lval, s=100, label="Constrained value")
ax.scatter(y, np.ones_like(y) * rval, marker="o", s=100, label="Unconstrained value")

style = "Simple, tail_width=0.25, head_width=2, head_length=8"
for i in range(len(x)):
if i%2 != 0:
a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=-.15", **kw)
if i % 2 != 0:
a = patches.FancyArrowPatch(
(x[i], lval), (y[i], rval), connectionstyle="arc3,rad=-.15", **kw
)
# a = patches.Arrow(lval, x[i], rval-lval, y[i]-x[i], width=0.05, color='k')
else:
a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=.005", **kw)
a = patches.FancyArrowPatch(
(x[i], lval), (y[i], rval), connectionstyle="arc3,rad=.005", **kw
)
ax.add_patch(a)

ax.get_yaxis().set_visible(False)
ax.spines["left"].set_visible(False)
ax.legend(loc='best')
ax.legend(loc="best")
# ax.set_ylim(0.1, 0.32)
plt.savefig('../_static/bijector_figure.svg', bbox_inches='tight')
plt.savefig("../_static/bijector_figure.svg", bbox_inches="tight")

# %%
np.log(0.05)
20 changes: 10 additions & 10 deletions docs/sharp_bits.md
Expand Up @@ -56,7 +56,7 @@ set's support and introduce a numerical and mathematical error into our model. F
example, consider the lengthscale parameter $`\ell`$, which we know must be strictly
positive. If at the $`t^{\text{th}}`$ iterate our current estimate of $`\ell`$ was
0.02 and our derivative informed us that $`\ell`$ should decrease, then a
learning rate greater than 0.03 would leave us with a negative lengthscale.
We visualise this issue below where the red cross denotes the invalid lengthscale value
that would be obtained, were we to optimise in the unconstrained parameter space.
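
As a rough, standalone sketch of the workaround discussed in the next hunk (not part of this diff; the step size and values below are made up), a TensorFlow Probability bijector lets us take the gradient step in an unconstrained space and then map back to a strictly positive lengthscale:

```python
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.bijectors as tfb

bij = tfb.Softplus()      # forward: ℝ -> (0, ∞); inverse: (0, ∞) -> ℝ
ell = jnp.array(0.02)     # current (constrained) lengthscale estimate

u = bij.inverse(ell)      # unconstrained representation, free to go negative
u = u - 0.03              # the step that would have produced an invalid lengthscale
ell_new = bij.forward(u)  # mapped back, the lengthscale remains strictly positive
```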

@@ -90,25 +90,25 @@ their own bijectors and attach them to the parameter(s) of their model.

### Why is positive-definiteness important?

The Gram matrix of a kernel, a concept that we explore more in our
[kernels notebook](examples/kernels.py) and our [PyTree notebook](examples/pytrees.md), is a
symmetric positive definite matrix. As such, we
have a range of tools at our disposal to make subsequent operations on the covariance
matrix faster. One of these tools is the Cholesky factorisation that uniquely decomposes
any symmetric positive-definite matrix $`\mathbf{\Sigma}`$ by

```math
\begin{align}
\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}\,,
\end{align}
```
where $`\mathbf{L}`$ is a lower triangular matrix.

We make use of this result in GPJax when solving linear systems of equations of the
form $`\mathbf{A}\boldsymbol{x} = \boldsymbol{b}`$. Whilst seemingly abstract at first,
such problems are frequently encountered when constructing Gaussian process models. One
such example is frequently encountered in the regression setting for learning Gaussian
process kernel hyperparameters. Here we have labels
$`\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \sigma^2\mathbf{I})`$ with $`f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})`$ arising from zero-mean
Gaussian process prior and Gram matrix $`\mathbf{K}_{\boldsymbol{xx}}`$ at the inputs
$`\boldsymbol{x}`$. Here the marginal log-likelihood comprises the following form
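
(The expression itself falls in the lines elided between this hunk and the next.) As a minimal, self-contained sketch of the kind of solve involved, using a toy Gram matrix rather than anything from GPJax, the Cholesky factor replaces an explicit inverse:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

K = jnp.array([[1.0, 0.5], [0.5, 1.0]])  # toy Gram matrix at two inputs
noise = 0.1                              # observation-noise variance
y = jnp.array([0.2, -0.4])               # toy labels

A = K + noise * jnp.eye(2)
L, lower = cho_factor(A, lower=True)           # A = L Lᵀ
alpha = cho_solve((L, lower), y)               # solves A · alpha = y
log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(L)))  # log|A| from the same factor
```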
@@ -153,11 +153,11 @@ negative eigenvalues, this violates the requirements and results in a "Cholesky

To resolve this, we apply some numerical _jitter_ to the diagonals of any Gram matrix.
Typically this is very small, with $`10^{-6}`$ being the system default. However,
for some problems, this amount may need to be increased.
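
A minimal sketch of the jitter trick with a toy matrix (not GPJax code): a Gram matrix whose rows are numerically indistinguishable may fail to factorise until a small constant is added to its diagonal.

```python
import jax.numpy as jnp

# Two nearly identical inputs can make the Gram matrix numerically singular,
# which is exactly the situation that triggers a Cholesky failure.
K = jnp.array([[1.0, 1.0 - 1e-9], [1.0 - 1e-9, 1.0]])

jitter = 1e-6                                     # the default amount mentioned above
L = jnp.linalg.cholesky(K + jitter * jnp.eye(2))  # factorises cleanly with the jitter
```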

## Slow-to-evaluate

Famously, a regular Gaussian process model (as detailed in
[our regression notebook](examples/regression.py)) will scale cubically in the number of data points.
Consequently, if you try to fit your Gaussian process model to a data set containing more
than several thousand data points, then you will likely incur a significant
@@ -175,5 +175,5 @@ above will become computationally infeasible. In such cases, we recommend using
uncollapsed evidence lower bound objective [@hensman2013gaussian] that allows stochastic
mini-batch optimisation of the parameters of your sparse Gaussian process model. Such a
model will scale linearly in the batch size and quadratically in the number of inducing
points. We demonstrate its use in
[our sparse stochastic variational inference notebook](examples/uncollapsed_vi.py).
3 changes: 3 additions & 0 deletions gpjax/base/module.py
@@ -244,6 +244,9 @@ def _get_trainables(meta_leaf):

return meta_map(_get_trainables, self)

def dict(self):
return {k: v for k, v in dataclasses.asdict(self).items()}


def _toplevel_meta(pytree: Any) -> List[Optional[Dict[str, Any]]]:
"""Unpacks a list of meta corresponding to the top-level nodes of the pytree.
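
The new `dict` helper above is essentially a thin wrapper around `dataclasses.asdict`; a standalone sketch of the same pattern on a hypothetical dataclass (not GPJax's `Module`) is:

```python
import dataclasses


@dataclasses.dataclass
class ToyLikelihood:
    num_datapoints: int = 10
    obs_noise: float = 1.0

    def dict(self) -> dict:
        # Same idea as Module.dict(): expose the fields as plain keyword arguments.
        return dataclasses.asdict(self)


print(ToyLikelihood().dict())  # {'num_datapoints': 10, 'obs_noise': 1.0}
```

This is what lets the new `expected_log_likelihood` method in `gpjax/likelihoods.py` forward likelihood parameters to its integrator via `**self.dict()`.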
2 changes: 1 addition & 1 deletion gpjax/gps.py
@@ -681,7 +681,7 @@ def construct_posterior(
def _build_fourier_features_fn(
prior: Prior, num_features: int, key: KeyArray
) -> Callable[[Float[Array, "N D"]], Float[Array, "N L"]]:
"""Return a function that evaluates features sampled from the Fourier feature
r"""Return a function that evaluates features sampled from the Fourier feature
decomposition of the prior's kernel.

Args:
84 changes: 84 additions & 0 deletions gpjax/integrators.py
@@ -0,0 +1,84 @@
from abc import abstractmethod
from dataclasses import dataclass

from beartype.typing import (
Any,
Callable,
)
import jax.numpy as jnp
from jaxtyping import Float
import numpy as np
from simple_pytree import Pytree

from gpjax.typing import Array


@dataclass
class AbstractIntegrator(Pytree):
@abstractmethod
def integrate(
self,
fun: Callable,
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
sigma2: Float[Array, "N D"],
**likelihood_params: Any,
):
raise NotImplementedError("self.integrate not implemented")

def __call__(
self,
fun: Callable,
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
sigma2: Float[Array, "N D"],
*args: Any,
**kwargs: Any,
):
return self.integrate(fun, y, mean, sigma2, *args, **kwargs)


@dataclass
class GHQuadratureIntegrator(AbstractIntegrator):
num_points: int = 20

def integrate(
self,
fun: Callable,
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
sigma2: Float[Array, "N D"],
**likelihood_params: Any,
) -> Float[Array, " N"]:
gh_points, gh_weights = np.polynomial.hermite.hermgauss(self.num_points)
sd = jnp.sqrt(sigma2)
# [Review comment] This might be a potential numerical issue to keep an eye on. If sigma2 gets sufficiently small, the gradients will explode.

X = mean + jnp.sqrt(2.0) * sd * gh_points
W = gh_weights / jnp.sqrt(jnp.pi)
val = jnp.sum(fun(X, y) * W, axis=1)
return val


@dataclass
class AnalyticalGaussianIntegrator(AbstractIntegrator):
def integrate(
self,
fun: Callable,
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
sigma2: Float[Array, "N D"],
**likelihood_params: Any,
) -> Float[Array, " N"]:
obs_noise = likelihood_params["obs_noise"].squeeze()
sq_error = jnp.square(y - mean)
log2pi = jnp.log(2.0 * jnp.pi)
val = jnp.sum(
log2pi + jnp.log(obs_noise) + (sq_error + sigma2) / obs_noise, axis=1
)
return -0.5 * val


__all__ = [
"AbstractIntegrator",
"GHQuadratureIntegrator",
"AnalyticalGaussianIntegrator",
]
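
A usage sketch for the two integrators above (toy arrays, and a hand-written Gaussian log-density standing in for a vmapped likelihood link function; the import path `gpjax.integrators` is the module added in this PR). For a Gaussian likelihood the quadrature and analytical results should agree closely; as noted in the review comment above, `jnp.sqrt(sigma2)` is worth watching when `sigma2` approaches zero.

```python
import jax.numpy as jnp

from gpjax.integrators import (
    AnalyticalGaussianIntegrator,
    GHQuadratureIntegrator,
)

y = jnp.array([[0.3], [-0.1]])        # observations, shape (N, 1)
mean = jnp.array([[0.25], [0.0]])     # variational means of q(f)
sigma2 = jnp.array([[0.04], [0.09]])  # variational variances of q(f)
obs_noise = jnp.array([0.1])          # likelihood noise variance


def gaussian_log_prob(f, y):
    # log N(y | f, obs_noise), evaluated pointwise; broadcasts over quadrature points.
    return -0.5 * (jnp.log(2.0 * jnp.pi * obs_noise) + (y - f) ** 2 / obs_noise)


quadrature = GHQuadratureIntegrator(num_points=20)(
    fun=gaussian_log_prob, y=y, mean=mean, sigma2=sigma2
)
analytical = AnalyticalGaussianIntegrator()(
    fun=gaussian_log_prob, y=y, mean=mean, sigma2=sigma2, obs_noise=obs_noise
)
print(quadrature, analytical)  # one expected log-density per datapoint, shape (N,)
```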
2 changes: 1 addition & 1 deletion gpjax/kernels/stationary/powered_exponential.py
@@ -34,7 +34,7 @@ class PoweredExponential(AbstractKernel):
r"""The powered exponential family of kernels. This also equivalent to the symmetric generalized normal distribution.

See Diggle and Ribeiro (2007) - "Model-based Geostatistics".
and
https://en.wikipedia.org/wiki/Generalized_normal_distribution#Symmetric_version

"""
17 changes: 17 additions & 0 deletions gpjax/likelihoods.py
@@ -18,6 +18,7 @@
Any,
Union,
)
from jax import vmap
import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Float
@@ -29,6 +30,11 @@
static_field,
)
from gpjax.gaussian_distribution import GaussianDistribution
from gpjax.integrators import (
AbstractIntegrator,
AnalyticalGaussianIntegrator,
GHQuadratureIntegrator,
)
from gpjax.linops.utils import to_dense
from gpjax.typing import (
Array,
@@ -44,6 +50,7 @@ class AbstractLikelihood(Module):
r"""Abstract base class for likelihoods."""

num_datapoints: int = static_field()
integrator: AbstractIntegrator = static_field(GHQuadratureIntegrator())

def __call__(self, *args: Any, **kwargs: Any) -> tfd.Distribution:
r"""Evaluate the likelihood function at a given predictive distribution.
@@ -85,6 +92,15 @@ def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution:
"""
raise NotImplementedError

def expected_log_likelihood(
self,
y: Float[Array, "N D"],
mu: Float[Array, "N D"],
sigma2: Float[Array, "N D"],
):
log_prob = vmap(lambda f, y: self.link_function(f).log_prob(y))
return self.integrator(fun=log_prob, y=y, mean=mu, sigma2=sigma2, **self.dict())


@dataclass
class Gaussian(AbstractLikelihood):
@@ -93,6 +109,7 @@ class Gaussian(AbstractLikelihood):
obs_noise: Union[ScalarFloat, Float[Array, "#N"]] = param_field(
jnp.array(1.0), bijector=tfb.Softplus()
)
integrator: AbstractIntegrator = static_field(AnalyticalGaussianIntegrator())

def link_function(self, f: Float[Array, "..."]) -> tfd.Normal:
r"""The link function of the Gaussian likelihood.
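
A rough sketch of exercising the new `expected_log_likelihood` hook directly, assuming the `Gaussian` likelihood can be constructed as below (toy arrays, not from this PR):

```python
import jax.numpy as jnp

from gpjax.likelihoods import Gaussian

likelihood = Gaussian(num_datapoints=2)  # defaults to the analytical integrator

y = jnp.array([[0.3], [-0.1]])        # observations
mu = jnp.array([[0.25], [0.0]])       # variational means of q(f)
sigma2 = jnp.array([[0.04], [0.09]])  # variational variances of q(f)

# One expected log-density per datapoint, shape (N,).
print(likelihood.expected_log_likelihood(y, mu, sigma2))
```

This is the path the ELBO in `gpjax/objectives.py` now takes in place of calling `gauss_hermite_quadrature` directly, as the next diff shows.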
7 changes: 1 addition & 6 deletions gpjax/objectives.py
@@ -15,7 +15,6 @@
from gpjax.dataset import Dataset
from gpjax.gaussian_distribution import GaussianDistribution
from gpjax.linops import identity
from gpjax.quadrature import gauss_hermite_quadrature
from gpjax.typing import (
Array,
ScalarFloat,
@@ -271,12 +270,8 @@ def q_moments(x):

mean, variance = vmap(q_moments)(x[:, None])

link_function = variational_family.posterior.likelihood.link_function
log_prob = vmap(lambda f, y: link_function(f).log_prob(y))

# ≈ ∫[log(p(y|f(x))) q(f(x))] df(x)
expectation = gauss_hermite_quadrature(log_prob, mean, jnp.sqrt(variance), y=y)

expectation = q.posterior.likelihood.expected_log_likelihood(y, mean, variance)
return expectation

