From 898a4dda7be05b3fdb054f31c564a572ba19a79a Mon Sep 17 00:00:00 2001 From: Tom Savage Date: Sat, 29 Jul 2023 20:22:33 +0100 Subject: [PATCH] Fixed Polar GP example --- docs/examples/constructing_new_kernels.py | 24 ++++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/docs/examples/constructing_new_kernels.py b/docs/examples/constructing_new_kernels.py index e218a5f7..88f82852 100644 --- a/docs/examples/constructing_new_kernels.py +++ b/docs/examples/constructing_new_kernels.py @@ -218,27 +218,21 @@ class Polar(gpx.kernels.AbstractKernel): period: float = static_field(2 * jnp.pi) tau: float = param_field(jnp.array([4.0]), bijector=bij) - def __post_init__(self): - self.c = self.period / 2.0 - def __call__( self, x: Float[Array, "1 D"], y: Float[Array, "1 D"] ) -> Float[Array, "1"]: - t = angular_distance(x, y, self.c) - K = (1 + self.tau * t / self.c) * jnp.clip( - 1 - t / self.c, 0, jnp.inf - ) ** self.tau + c = self.period / 2.0 + t = angular_distance(x, y, c) + K = (1 + self.tau * t / c) * jnp.clip(1 - t / c, 0, jnp.inf) ** self.tau return K.squeeze() # %% [markdown] # We unpack this now to make better sense of it. In the kernel's initialiser # we specify the length of a single period. As the underlying -# domain is a circle, this is $2\pi$. Next, we define -# the Kernel's half-period parameter. As the kernel is a `dataclass` and `c` is -# function of `period`, we must define it in the `__post_init__` method. -# Finally, we define the kernel's `__call__` -# function which is a direct implementation of Equation (1). +# domain is a circle, this is $2\pi$. We then define the kernel's `__call__` +# function which is a direct implementation of Equation (1) where we define `c` +# as half the value of `period`. # # To constrain $\tau$ to be greater than 4, we use a `Softplus` bijector with a # clipped lower bound of 4.0. This is done by specifying the `bijector` argument @@ -267,11 +261,11 @@ def __call__( PKern = Polar() meanf = gpx.mean_functions.Zero() likelihood = gpx.Gaussian(num_datapoints=n) -circlular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood +circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood # Optimise GP's marginal log-likelihood using Adam opt_posterior, history = gpx.fit( - model=circlular_posterior, + model=circular_posterior, objective=jit(gpx.ConjugateMLL(negative=True)), train_data=D, optim=ox.adamw(learning_rate=0.05), @@ -302,6 +296,7 @@ def __call__( alpha=0.3, label=r"1 Posterior s.d.", color=cols[1], + lw=0, ) ax.fill_between( angles.squeeze(), @@ -310,6 +305,7 @@ def __call__( alpha=0.15, label=r"3 Posterior s.d.", color=cols[1], + lw=0, ) ax.plot(angles, mu, label="Posterior mean") ax.scatter(D.X, D.y, alpha=1, label="Observations")