From 5f51cfe76fa0bec6dfd7fcbfc69495ff126b6818 Mon Sep 17 00:00:00 2001
From: Daniel Dodd
Date: Fri, 22 Mar 2024 10:19:30 +0000
Subject: [PATCH] Update likelihoods_guide.py

---
 docs/examples/likelihoods_guide.py | 28 +++++++++++++++++++---------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/docs/examples/likelihoods_guide.py b/docs/examples/likelihoods_guide.py
index f1d36d0f..a04b2c52 100644
--- a/docs/examples/likelihoods_guide.py
+++ b/docs/examples/likelihoods_guide.py
@@ -1,3 +1,4 @@
+# %% [markdown]
 # # Likelihood guide
 #
 # In this notebook, we will walk users through the process of creating a new likelihood
@@ -48,7 +49,7 @@
 # these methods in the forthcoming sections, but first, we will show how to instantiate
 # a likelihood object. To do this, we'll need a dataset.

-# +
+# %%
 # Enable Float64 for more stable matrix inversions.
 from jax import config

@@ -80,8 +81,8 @@
 ax.plot(x, y, "o", label="Observations")
 ax.plot(x, f(x), label="Latent function")
 ax.legend()
-# -

+# %% [markdown]
 # In this example, our observations have support $[-3, 3]$ and are generated from a
 # sinusoidal function with Gaussian noise. As such, our response values $\mathbf{y}$
 # range between $-1$ and $1$, subject to Gaussian noise. Due to this, a Gaussian
@@ -92,8 +93,10 @@
 # instantiating a likelihood object. We do this by specifying the `num_datapoints`
 # argument.

+# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n)

+# %% [markdown]
 # ### Likelihood parameters
 #
 # Some likelihoods, such as the Gaussian likelihood, contain parameters that we seek
@@ -105,8 +108,10 @@
 # initialise the likelihood standard deviation with a value of $0.5$, then we would do
 # this as follows:

+# %%
 gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5)

+# %% [markdown]
 # To control other properties of the observation noise such as trainability and value
 # constraints, see our [PyTree guide](pytrees.md).
 #
@@ -123,7 +128,7 @@
 # samples of $\mathbf{f}^{\star}$, whilst in red we see samples of
 # $\mathbf{y}^{\star}$.

-# +
+# %%
 kernel = gpx.kernels.Matern32()
 meanf = gpx.mean_functions.Zero()
 prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
@@ -153,11 +158,11 @@
     color=cols[1],
     label="Predictive samples",
 )
-# -

+# %% [markdown]
 # Similarly, for a Bernoulli likelihood function, the samples of $y$ would be binary.

-# +
+# %%
 likelihood = gpx.likelihoods.Bernoulli(num_datapoints=D.n)


@@ -180,8 +185,8 @@
     color=cols[1],
     label="Predictive samples",
 )
-# -

+# %% [markdown]
 # ### Link functions
 #
 # In the above figure, we can see the latent samples being constrained to be either 0 or
@@ -229,7 +234,7 @@
 # this, let us consider a Gaussian likelihood where we'll first define a variational
 # approximation to the posterior.

-# +
+# %%
 z = jnp.linspace(-3.0, 3.0, 10).reshape(-1, 1)
 q = gpx.variational_families.VariationalGaussian(posterior=posterior, inducing_inputs=z)

@@ -240,27 +245,32 @@ def q_moments(x):


 mean, variance = jax.vmap(q_moments)(x[:, None])
-# -

+# %% [markdown]
 # Now that we have the variational mean and variational (co)variance, we can compute
 # the expected log-likelihood using the `expected_log_likelihood` method of the
 # likelihood object.

+# %%
 jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))

+# %% [markdown]
 # However, had we wanted to do this using quadrature, then we would have done the
 # following:

+# %%
 lquad = gpx.likelihoods.Gaussian(
     num_datapoints=D.n,
     obs_stddev=jnp.array([0.1]),
-    integrator=gpx.integrators.GHQuadratureIntegrator(num_points=20),
 )

+# %% [markdown]
 # However, this is not recommended for the Gaussian likelihood given that the
 # expectation can be computed analytically.

+# %% [markdown]
 # ## System configuration

+# %%
 # %reload_ext watermark
 # %watermark -n -u -v -iv -w -a 'Thomas Pinder'
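
Illustrative usage sketch (not part of the diff above): the snippet below exercises the `expected_log_likelihood` call that the final hunk reorganises, cross-checking GPJax's analytic result against a crude Monte Carlo estimate. Only `gpx.likelihoods.Gaussian` and `expected_log_likelihood`, as they appear in the guide, are assumed; the toy variational moments, the sample count, and the hand-rolled Gaussian log-density are invented for illustration.

import jax
import jax.numpy as jnp
import gpjax as gpx

key = jax.random.PRNGKey(0)
n, obs_stddev = 20, 0.1
y = jnp.sin(jnp.linspace(-3.0, 3.0, n)).reshape(-1, 1)   # toy observations
mean = 0.9 * y                                            # toy variational means
variance = 0.05 * jnp.ones((n, 1))                        # toy variational variances

likelihood = gpx.likelihoods.Gaussian(num_datapoints=n, obs_stddev=obs_stddev)
analytic = jnp.sum(likelihood.expected_log_likelihood(y=y, mean=mean, variance=variance))

# Crude Monte Carlo check: draw f ~ N(mean, variance) and average log N(y | f, obs_stddev^2).
f_samples = mean + jnp.sqrt(variance) * jax.random.normal(key, shape=(n, 2000))
log_density = -0.5 * jnp.log(2.0 * jnp.pi * obs_stddev**2) - 0.5 * ((y - f_samples) / obs_stddev) ** 2
monte_carlo = jnp.sum(jnp.mean(log_density, axis=1))

print(analytic, monte_carlo)

Because the Gaussian likelihood evaluates this expectation analytically, as the guide's closing remark notes, the two totals should agree to within Monte Carlo error.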