Add classes for handling posteriors during the decision making loop #362

Merged: Thomas-Christie merged 3 commits into JaxGaussianProcesses:tchristie/bo from Thomas-Christie:posterior-handling on Aug 24, 2023.
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass

from beartype.typing import (
    Callable,
    Optional,
)
import optax as ox

import gpjax as gpx
from gpjax.dataset import Dataset
from gpjax.gps import (
    AbstractLikelihood,
    AbstractPosterior,
    AbstractPrior,
)
from gpjax.objectives import AbstractObjective
from gpjax.typing import KeyArray

LikelihoodBuilder = Callable[[int], AbstractLikelihood]
"""Type alias for likelihood builders, which take the number of datapoints as input and
return a likelihood object initialised with the given number of datapoints."""


@dataclass
class PosteriorHandler:
    """
    Class for handling the creation and updating of a GP posterior as new data is
    observed.

    Attributes:
        prior (AbstractPrior): Prior to use when forming the posterior.
        likelihood_builder (LikelihoodBuilder): Function which takes the number of
            datapoints as input and returns a likelihood object initialised with the
            given number of datapoints.
        optimization_objective (AbstractObjective): Objective to use for optimizing the
            posterior hyperparameters.
        optimizer (ox.GradientTransformation): Optax optimizer to use for optimizing the
            posterior hyperparameters.
        num_optimization_iters (int): Number of iterations for which to optimize the
            posterior hyperparameters.
    """

    prior: AbstractPrior
    likelihood_builder: LikelihoodBuilder
    optimization_objective: AbstractObjective
    optimizer: ox.GradientTransformation
    num_optimization_iters: int

    def __post_init__(self):
        if self.num_optimization_iters < 1:
            raise ValueError("num_optimization_iters must be greater than 0.")

    def get_posterior(
        self, dataset: Dataset, optimize: bool, key: Optional[KeyArray] = None
    ) -> AbstractPosterior:
        """
        Initialise (and optionally optimize) a posterior using the given dataset.

        Args:
            dataset (Dataset): Dataset to get posterior for.
            optimize (bool): Whether to optimize the posterior hyperparameters.
            key (Optional[KeyArray]): A JAX PRNG key, used for optimizing the
                posterior hyperparameters.

        Returns:
            Posterior for the given dataset.
        """
        posterior = self.prior * self.likelihood_builder(dataset.n)

        if optimize:
            if key is None:
                raise ValueError(
                    "A key must be provided in order to optimize the posterior."
                )
            posterior = self._optimize_posterior(posterior, dataset, key)

        return posterior

    def update_posterior(
        self,
        dataset: Dataset,
        previous_posterior: AbstractPosterior,
        optimize: bool,
        key: Optional[KeyArray] = None,
    ) -> AbstractPosterior:
        """
        Update the given posterior with the given dataset. This needs to be done
        whenever the number of datapoints in the (training) dataset of the posterior
        changes, as the `AbstractLikelihood` class requires the number of datapoints
        to be specified. Hyperparameters may or may not be optimized, depending on
        the value of the `optimize` parameter. Note that the updated posterior will
        be initialised with the same prior hyperparameters as the previous posterior,
        but the likelihood will be re-initialised with the new number of datapoints,
        with its hyperparameters set as in the `likelihood_builder` function.

        Args:
            dataset: Dataset to get posterior for.
            previous_posterior: Posterior being updated. This is supplied because one
                may wish to simply increase the number of datapoints in the
                likelihood, without optimizing the posterior hyperparameters, in
                which case the previous posterior can be used to obtain the
                previously set prior hyperparameters.
            optimize: Whether to optimize the posterior hyperparameters.
            key: A JAX PRNG key, used for optimizing the posterior hyperparameters.

        Returns:
            Updated posterior for the given dataset.
        """
        posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)

        if optimize:
            if key is None:
                raise ValueError(
                    "A key must be provided in order to optimize the posterior."
                )
            posterior = self._optimize_posterior(posterior, dataset, key)
        return posterior

    def _optimize_posterior(
        self, posterior: AbstractPosterior, dataset: Dataset, key: KeyArray
    ) -> AbstractPosterior:
        """
        Takes a posterior and corresponding dataset and optimizes the posterior using
        the GPJax `fit` method.

        Args:
            posterior: Posterior being optimized.
            dataset: Dataset used for optimizing the posterior.
            key: A JAX PRNG key for generating random numbers.

        Returns:
            Optimized posterior.
        """
        opt_posterior, _ = gpx.fit(
            model=posterior,
            objective=self.optimization_objective,
            train_data=dataset,
            optim=self.optimizer,
            num_iters=self.num_optimization_iters,
            safe=True,
            key=key,
            verbose=False,
        )

        return opt_posterior
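Exercising `PosteriorHandler` end-to-end requires a GPJax prior, objective, and optimizer. As a dependency-free illustration of the pattern it follows (a builder callable mapping dataset size to a likelihood, plus the iteration-count guard in `__post_init__`), here is a minimal sketch; `ToyLikelihood`, `ToyHandler`, and `builder` are hypothetical stand-ins, not GPJax API:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ToyLikelihood:
    # stands in for AbstractLikelihood: only records the dataset size
    num_datapoints: int


# mirrors the LikelihoodBuilder alias: int -> likelihood
builder: Callable[[int], ToyLikelihood] = lambda n: ToyLikelihood(num_datapoints=n)


@dataclass
class ToyHandler:
    likelihood_builder: Callable[[int], ToyLikelihood]
    num_optimization_iters: int

    def __post_init__(self):
        # same guard as PosteriorHandler.__post_init__
        if self.num_optimization_iters < 1:
            raise ValueError("num_optimization_iters must be greater than 0.")

    def get_likelihood(self, n: int) -> ToyLikelihood:
        # analogous to the likelihood_builder call inside get_posterior
        return self.likelihood_builder(n)


handler = ToyHandler(likelihood_builder=builder, num_optimization_iters=10)
print(handler.get_likelihood(5).num_datapoints)  # 5
```

In GPJax itself the builder would typically return a real likelihood, e.g. `lambda n: gpx.likelihoods.Gaussian(num_datapoints=n)`.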
gpjax/decision_making/test_functions/non_conjugate_functions.py (90 additions, 0 deletions)
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import abstractmethod
from dataclasses import dataclass

import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
    Array,
    Float,
    Integer,
)

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.typing import KeyArray


@dataclass
class PoissonTestFunction:
    """
    Test function for GPs utilising the Poisson likelihood. Function taken from
    https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

    Attributes:
        search_space (ContinuousSearchSpace): Search space for the function.
    """

    search_space = ContinuousSearchSpace(
        lower_bounds=jnp.array([-2.0]),
        upper_bounds=jnp.array([2.0]),
    )

    def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
        """
        Generate a toy dataset from the test function.

        Args:
            num_points (int): Number of points to sample.
            key (KeyArray): JAX PRNG key.

        Returns:
            Dataset: Dataset of points sampled from the test function.
        """
        X = self.search_space.sample(num_points=num_points, key=key)
        y = self.evaluate(X)
        return Dataset(X=X, y=y)

    def generate_test_points(
        self, num_points: int, key: KeyArray
    ) -> Float[Array, "N D"]:
        """
        Generate test points from the search space of the test function.

        Args:
            num_points (int): Number of points to sample.
            key (KeyArray): JAX PRNG key.

        Returns:
            Float[Array, 'N D']: Test points sampled from the search space.
        """
        return self.search_space.sample(num_points=num_points, key=key)

    @abstractmethod
    def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
        """
        Evaluate the test function at a set of points. Function taken from
        https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

        Args:
            x (Float[Array, 'N D']): Points to evaluate the test function at.

        Returns:
            Integer[Array, 'N 1']: Values of the test function at the points.
        """
        key = jr.PRNGKey(42)
        f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
        return jr.poisson(key, jnp.exp(f(x)))
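The evaluation logic above can be exercised in isolation with plain JAX. The snippet below (a sketch; the choice of five evenly spaced input points is arbitrary, not taken from the PR) reproduces the latent function f(x) = 2 sin(3x) + 0.5x and draws Poisson counts from its exponentiated values, as `evaluate` does:

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(42)  # same fixed seed used in PoissonTestFunction.evaluate
x = jnp.linspace(-2.0, 2.0, 5).reshape(-1, 1)  # points in the search space [-2, 2]

# latent log-rate function from the GPJax Poisson example
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x

# Poisson counts with rate exp(f(x)): one non-negative integer per input point
y = jr.poisson(key, jnp.exp(f(x)))
print(y.shape)  # (5, 1)
```

Because the PRNG key is fixed inside `evaluate`, repeated calls on the same inputs yield identical counts, which keeps the test function deterministic across runs.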
Nice, so this steals the optimized kernel params from the last go?
Yes, this is included to cover the scenario where e.g. a user wishes to run the decision making loop but doesn't want to optimise the posterior on each iteration. With this logic they can still update the likelihood to reflect the change in dataset size, without changing the prior parameters.
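The behaviour described in this exchange, reusing the previously optimised prior while rebuilding only the likelihood, can be sketched with toy stand-ins (`Prior`, `Likelihood`, `Posterior`, and `update_posterior` here are hypothetical mocks, not GPJax classes):

```python
from dataclasses import dataclass


@dataclass
class Prior:
    lengthscale: float  # stands in for a trained kernel hyperparameter


@dataclass
class Likelihood:
    num_datapoints: int


@dataclass
class Posterior:
    prior: Prior
    likelihood: Likelihood


def update_posterior(previous: Posterior, n_new: int) -> Posterior:
    # mirrors PosteriorHandler.update_posterior with optimize=False:
    # keep the previous prior (and hence its trained hyperparameters),
    # rebuild only the likelihood for the new dataset size
    return Posterior(prior=previous.prior, likelihood=Likelihood(num_datapoints=n_new))


old = Posterior(Prior(lengthscale=1.7), Likelihood(num_datapoints=10))
new = update_posterior(old, n_new=11)
print(new.prior.lengthscale, new.likelihood.num_datapoints)  # 1.7 11
```

The point of the design is visible in the last line: the "stolen" kernel parameters survive the update unchanged while the likelihood tracks the grown dataset.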