Commit 5f51cfe

Update likelihoods_guide.py

daniel-dodd committed Mar 22, 2024
1 parent fcc75a6 commit 5f51cfe
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions docs/examples/likelihoods_guide.py
@@ -1,3 +1,4 @@
+# %% [markdown]
 # # Likelihood guide
 #
 # In this notebook, we will walk users through the process of creating a new likelihood
@@ -48,7 +49,7 @@
 # these methods in the forthcoming sections, but first, we will show how to instantiate
 # a likelihood object. To do this, we'll need a dataset.
 
-# +
+# %%
 # Enable Float64 for more stable matrix inversions.
 from jax import config
 
@@ -80,8 +81,8 @@
 ax.plot(x, y, "o", label="Observations")
 ax.plot(x, f(x), label="Latent function")
 ax.legend()
-# -
 
+# %% [markdown]
 # In this example, our observations have support $[-3, 3]$ and are generated from a
 # sinusoidal function with Gaussian noise. As such, our response values $\mathbf{y}$
 # range between $-1$ and $1$, subject to Gaussian noise. Due to this, a Gaussian
@@ -92,8 +93,10 @@
 # instantiating a likelihood object. We do this by specifying the `num_datapoints`
 # argument.
 
+# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n)
 
+# %% [markdown]
 # ### Likelihood parameters
 #
 # Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
@@ -105,8 +108,10 @@
 # initialise the likelihood standard deviation with a value of $0.5$, then we would do
 # this as follows:
 
+# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)
 
+# %% [markdown]
 # To control other properties of the observation noise such as trainability and value
 # constraints, see our [PyTree guide](pytrees.md).
 #
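In practice, freezing the noise parameter during optimisation typically takes the shape sketched below. This is a hedged sketch reusing the dataset `D` from above; `replace_trainable` is an assumed method name taken from the PyTree guide's API, so verify it against your GPJax version before relying on it.

import gpjax as gpx

likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)
# Assumed API: mark `obs_stddev` as non-trainable so an optimiser leaves it fixed.
likelihood = likelihood.replace_trainable(obs_stddev=False)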
@@ -123,7 +128,7 @@
 # samples of $\mathbf{f}^{\star}$, whilst in red we see samples of
 # $\mathbf{y}^{\star}$.
 
-# +
+# %%
 kernel = gpx.kernels.Matern32()
 meanf = gpx.mean_functions.Zero()
 prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
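For orientation, the latent and predictive samples contrasted here are usually produced with the pattern below. This is a hedged sketch rather than this file's verbatim code: `xtest` is a hypothetical grid of test inputs, and `likelihood` is assumed to be the Gaussian likelihood from above.

import jax.numpy as jnp

xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)  # hypothetical test inputs
posterior = prior * likelihood  # combine the prior and likelihood into a posterior
latent_dist = posterior.predict(xtest, train_data=D)  # distribution over f*
predictive_dist = likelihood(latent_dist)  # push f* through the likelihood to obtain y*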
@@ -153,11 +158,11 @@
     color=cols[1],
     label="Predictive samples",
 )
-# -
 
+# %% [markdown]
 # Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.
 
-# +
+# %%
 likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)


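To see why these samples come out binary, consider a minimal plain-JAX sketch of the mechanism, assuming a probit-style link of the kind described in the link-function section below; `f_latent` is a hypothetical set of latent function values.

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

key = jr.PRNGKey(0)
f_latent = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])  # hypothetical latent values
probs = norm.cdf(f_latent)  # an inverse-probit link squashes the reals into (0, 1)
y_samples = jr.bernoulli(key, probs)  # binary draws, as in the red predictive samples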
@@ -180,8 +185,8 @@
     color=cols[1],
     label="Predictive samples",
 )
-# -
 
+# %% [markdown]
 # ### Link functions
 #
 # In the above figure, we can see the latent samples being constrained to be either 0 or
@@ -229,7 +234,7 @@
 # this, let us consider a Gaussian likelihood where we'll first define a variational
 # approximation to the posterior.
 
-# +
+# %%
 z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
 q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)
 
@@ -240,27 +245,32 @@ def q_moments(x):


 mean, variance = jax.vmap(q_moments)(x[:, None])
-# -
 
+# %% [markdown]
 # Now that we have the variational mean and variational (co)variance, we can compute
 # the expected log-likelihood using the `expected_log_likelihood` method of the
 # likelihood object.
 
+# %%
 jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))

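For intuition, the expected log-likelihood can be cross-checked with a Monte Carlo estimate. The sketch below uses hypothetical scalar moments `m` and `v`, noise scale `sigma`, and observation `y_obs`, and spells out the closed form available in the Gaussian case.

import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

key = jr.PRNGKey(42)
m, v, sigma, y_obs = 0.2, 0.3, 0.5, 0.1  # hypothetical moments, noise scale, observation
f_samples = m + jnp.sqrt(v) * jr.normal(key, (10_000,))  # draws of f ~ N(m, v)
mc_ell = jnp.mean(norm.logpdf(y_obs, loc=f_samples, scale=sigma))
# Closed form: E[log N(y | f, sigma^2)] = -0.5*log(2*pi*sigma^2) - ((y - m)^2 + v) / (2*sigma^2)
analytic_ell = -0.5 * jnp.log(2 * jnp.pi * sigma**2) - ((y_obs - m) ** 2 + v) / (2 * sigma**2)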
+# %% [markdown]
 # However, had we wanted to do this using quadrature, then we would have done the
 # following:
 
+# %%
 lquad = gpx.likelihoods.Gaussian(
     num_datapoints=D.n,
     obs_stddev=jnp.array([0.1]),
     integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
 )

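Under the hood, a 20-point Gauss-Hermite rule approximates the same expectation roughly as follows. This sketches the quadrature idea only, not GPJax's implementation, and reuses the hypothetical `m`, `v`, `sigma`, and `y_obs` from the sketch above.

import numpy as np

nodes, weights = np.polynomial.hermite.hermgauss(20)  # physicists' Gauss-Hermite rule
f_nodes = m + jnp.sqrt(2.0 * v) * nodes  # change of variables f = m + sqrt(2v) * x
gh_ell = jnp.sum(weights * norm.logpdf(y_obs, loc=f_nodes, scale=sigma)) / jnp.sqrt(jnp.pi)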
+# %% [markdown]
 # However, this is not recommended for the Gaussian likelihood given that the
 # expectation can be computed analytically.
 
+# %% [markdown]
 # ## System configuration
 
+# %%
 # %reload_ext watermark
 # %watermark -n -u -v -iv -w -a 'Thomas Pinder'
