Skip to content

Commit

Permalink
Adds a simple test for PI
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Jul 1, 2024
1 parent 623181a commit 5d0a84c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ProbabilityOfImprovement(AbstractSinglePointUtilityFunctionBuilder):
of the objective function over the best observed value.
More precisely, given a predictive posterior distribution of the objective
function, the probability of improvement at a test point $`x`$ is defined as:
function $`f`$, the probability of improvement at a test point $`x`$ is defined as:
$$`\text{PI}(x) = \text{Prob}[f(x) < f(x_{\text{best}})]`$$
where $`x_{\text{best}}`$ is the minimizer of $`f`$ in the dataset.
Expand Down Expand Up @@ -80,7 +80,8 @@ def build_utility_function(
to form the utility function. Keys in `datasets` should correspond to
keys in `posteriors`. One of the datasets must correspond
to the `OBJECTIVE` key.
key (KeyArray): JAX PRNG key used for random number generation. Since the probability of improvement is computed deterministically
key (KeyArray): JAX PRNG key used for random number generation. Since
the probability of improvement is computed deterministically
from the predictive posterior, the key is not used.
Returns:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr
import jax.numpy as jnp

from gpjax.decision_making.test_functions.continuous_functions import Forrester
from gpjax.decision_making.utility_functions.probability_of_improvement import (
ProbabilityOfImprovement,
)
from gpjax.decision_making.utils import OBJECTIVE
from tests.test_decision_making.utils import generate_dummy_conjugate_posterior


def test_probability_of_improvement_gives_correct_value_for_a_seed():
key = jr.key(42)
forrester = Forrester()
dataset = forrester.generate_dataset(num_points=10, key=key)
posterior = generate_dummy_conjugate_posterior(dataset)
posteriors = {OBJECTIVE: posterior}
datasets = {OBJECTIVE: dataset}

pi_utility_builder = ProbabilityOfImprovement()
pi_utility = pi_utility_builder.build_utility_function(
posteriors=posteriors, datasets=datasets, key=key
)

test_X = forrester.generate_test_points(num_points=10, key=key)
utility_values = pi_utility(test_X)

expected_utility_values = jnp.array(
[
7.30230451e-05,
5.00322831e-05,
1.06219741e-03,
2.19520435e-03,
3.49279363e-05,
1.66031943e-04,
2.78478912e-04,
3.35871920e-04,
1.38265233e-04,
3.63297977e-05,
]
).reshape(-1, 1)

assert utility_values.shape == (10, 1)
assert jnp.isclose(utility_values, expected_utility_values).all()

0 comments on commit 5d0a84c

Please sign in to comment.