Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add classes for handling posteriors during the decision making loop #362

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gpjax/decision_making/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AbstractAcquisitionMaximizer,
ContinuousAcquisitionMaximizer,
)
from gpjax.decision_making.posterior_handler import PosteriorHandler
from gpjax.decision_making.search_space import (
AbstractSearchSpace,
ContinuousSearchSpace,
Expand All @@ -42,6 +43,7 @@
"AbstractContinuousTestFunction",
"Forrester",
"LogarithmicGoldsteinPrice",
"PosteriorHandler",
"Quadratic",
"ThompsonSampling",
]
155 changes: 155 additions & 0 deletions gpjax/decision_making/posterior_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass

from beartype.typing import (
Callable,
Optional,
)
import optax as ox

import gpjax as gpx
from gpjax.dataset import Dataset
from gpjax.gps import (
AbstractLikelihood,
AbstractPosterior,
AbstractPrior,
)
from gpjax.objectives import AbstractObjective
from gpjax.typing import KeyArray

# Callable alias: maps a dataset size (number of datapoints) to a freshly built
# likelihood object, letting the handler rebuild the likelihood as data grows.
LikelihoodBuilder = Callable[[int], AbstractLikelihood]
"""Type alias for likelihood builders, which take the number of datapoints as input and
return a likelihood object initialised with the given number of datapoints."""


@dataclass
class PosteriorHandler:
    """
    Class for handling the creation and updating of a GP posterior as new data is
    observed.

    Attributes:
        prior (AbstractPrior): Prior to use when forming the posterior.
        likelihood_builder (LikelihoodBuilder): Function which takes the number of
            datapoints as input and returns a likelihood object initialised with the
            given number of datapoints.
        optimization_objective (AbstractObjective): Objective to use for optimizing the
            posterior hyperparameters.
        optimizer (ox.GradientTransformation): Optax optimizer to use for optimizing
            the posterior hyperparameters.
        num_optimization_iters (int): Number of iterations to optimize the posterior
            hyperparameters for.
    """

    prior: AbstractPrior
    likelihood_builder: LikelihoodBuilder
    optimization_objective: AbstractObjective
    optimizer: ox.GradientTransformation
    num_optimization_iters: int

    def __post_init__(self):
        # Fail at construction time rather than on the first optimization call.
        if self.num_optimization_iters < 1:
            raise ValueError("num_optimization_iters must be greater than 0.")

    def get_posterior(
        self, dataset: Dataset, optimize: bool, key: Optional[KeyArray] = None
    ) -> AbstractPosterior:
        """
        Initialise (and optionally optimize) a posterior using the given dataset.

        Args:
            dataset (Dataset): Dataset to get posterior for.
            optimize (bool): Whether to optimize the posterior hyperparameters.
            key (Optional[KeyArray]): A JAX PRNG key which is used for optimizing the
                posterior hyperparameters.

        Returns:
            Posterior for the given dataset.

        Raises:
            ValueError: If `optimize` is True but no `key` is provided.
        """
        # The likelihood must be built with the current dataset size, since
        # likelihood objects are initialised with a fixed number of datapoints.
        posterior = self.prior * self.likelihood_builder(dataset.n)

        if optimize:
            if key is None:
                raise ValueError(
                    "A key must be provided in order to optimize the posterior."
                )
            posterior = self._optimize_posterior(posterior, dataset, key)

        return posterior

    def update_posterior(
        self,
        dataset: Dataset,
        previous_posterior: AbstractPosterior,
        optimize: bool,
        key: Optional[KeyArray] = None,
    ) -> AbstractPosterior:
        """
        Update the given posterior with the given dataset. This needs to be done when
        the number of datapoints in the (training) dataset of the posterior changes, as
        the `AbstractLikelihood` class requires the number of datapoints to be
        specified. Hyperparameters may or may not be optimized, depending on the value
        of the `optimize` parameter. Note that the updated posterior will be
        initialised with the same prior hyperparameters as the previous posterior, but
        the likelihood will be re-initialised with the new number of datapoints, and
        hyperparameters set as in the `likelihood_builder` function.

        Args:
            dataset: Dataset to get posterior for.
            previous_posterior: Posterior being updated. This is supplied as one may
                wish to simply increase the number of datapoints in the likelihood,
                without optimizing the posterior hyperparameters, in which case the
                previous posterior can be used to obtain the previously set prior
                hyperparameters.
            optimize: Whether to optimize the posterior hyperparameters.
            key: A JAX PRNG key which is used for optimizing the posterior
                hyperparameters.

        Returns:
            Updated posterior for the given dataset.

        Raises:
            ValueError: If `optimize` is True but no `key` is provided.
        """
        # Reuse the previous posterior's prior so its (possibly optimized)
        # hyperparameters carry over; only the likelihood is rebuilt for the new
        # dataset size.
        posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)

        if optimize:
            if key is None:
                raise ValueError(
                    "A key must be provided in order to optimize the posterior."
                )
            posterior = self._optimize_posterior(posterior, dataset, key)
        return posterior

    def _optimize_posterior(
        self, posterior: AbstractPosterior, dataset: Dataset, key: KeyArray
    ) -> AbstractPosterior:
        """
        Takes a posterior and corresponding dataset and optimizes the posterior using
        the GPJax `fit` method.

        Args:
            posterior: Posterior being optimized.
            dataset: Dataset used for optimizing posterior.
            key: A JAX PRNG key for generating random numbers.

        Returns:
            Optimized posterior.
        """
        opt_posterior, _ = gpx.fit(
            model=posterior,
            objective=self.optimization_objective,
            train_data=dataset,
            optim=self.optimizer,
            num_iters=self.num_optimization_iters,
            safe=True,
            key=key,
            verbose=False,
        )

        return opt_posterior
5 changes: 4 additions & 1 deletion gpjax/decision_making/test_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from gpjax.decision_making.test_functions.continuous_functions import (
AbstractContinuousTestFunction,
Forrester,
LogarithmicGoldsteinPrice,
Quadratic,
)
from gpjax.decision_making.test_functions.non_conjugate_functions import (
PoissonTestFunction,
)

# Public API of the test_functions subpackage; names kept alphabetically sorted.
__all__ = [
    "AbstractContinuousTestFunction",
    "Forrester",
    "LogarithmicGoldsteinPrice",
    "PoissonTestFunction",
    "Quadratic",
]
90 changes: 90 additions & 0 deletions gpjax/decision_making/test_functions/non_conjugate_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import abstractmethod
from dataclasses import dataclass

import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Array,
Float,
Integer,
)

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.typing import KeyArray


@dataclass
class PoissonTestFunction:
    """
    Test function for GPs utilising the Poisson likelihood. Function taken from
    https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

    Attributes:
        search_space (ContinuousSearchSpace): Search space for the function.
    """

    # NOTE(review): no type annotation, so this is a plain class attribute shared
    # by all instances (not a dataclass field) — confirm this is intentional.
    search_space = ContinuousSearchSpace(
        lower_bounds=jnp.array([-2.0]),
        upper_bounds=jnp.array([2.0]),
    )

    def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
        """
        Generate a toy dataset from the test function.

        Args:
            num_points (int): Number of points to sample.
            key (KeyArray): JAX PRNG key.

        Returns:
            Dataset: Dataset of points sampled from the test function.
        """
        X = self.search_space.sample(num_points=num_points, key=key)
        y = self.evaluate(X)
        return Dataset(X=X, y=y)

    def generate_test_points(
        self, num_points: int, key: KeyArray
    ) -> Float[Array, "N D"]:
        """
        Generate test points from the search space of the test function.

        Args:
            num_points (int): Number of points to sample.
            key (KeyArray): JAX PRNG key.

        Returns:
            Float[Array, 'N D']: Test points sampled from the search space.
        """
        return self.search_space.sample(num_points=num_points, key=key)

    # NOTE(review): decorated @abstractmethod yet carries a concrete body and is
    # called by generate_dataset, while the class does not use an ABC metaclass —
    # the decorator has no enforcement effect here; confirm whether this class is
    # meant to be subclassed.
    @abstractmethod
    def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
        """
        Evaluate the test function at a set of points. Function taken from
        https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.

        Args:
            x (Float[Array, 'N 1']): Points to evaluate the test function at.

        Returns:
            Integer[Array, 'N 1']: Values of the test function at the points.
        """
        # Fixed key so repeated evaluations of the same inputs yield identical
        # Poisson draws (deterministic test function).
        key = jr.PRNGKey(42)
        f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
        return jr.poisson(key, jnp.exp(f(x)))
Loading