Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add search space class for BayesOpt #355

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions gpjax/bayes_opt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from gpjax.bayes_opt import search_space

__all__ = [
"search_space",
]
96 changes: 96 additions & 0 deletions gpjax/bayes_opt/search_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import (
ABC,
abstractmethod,
)
from dataclasses import dataclass

from jaxtyping import Float
import tensorflow_probability.substrates.jax as tfp

from gpjax.typing import (
Array,
KeyArray,
)


@dataclass
class AbstractSearchSpace(ABC):
    """Base interface for search spaces.

    A search space defines the domain used by the sampling and optimisation
    functionality in GPJax. Subclasses must implement both `dimensionality`
    and `sample`.
    """

    @property
    @abstractmethod
    def dimensionality(self) -> int:
        """Number of dimensions of the search space.

        Returns:
            int: Dimensionality of the search space.
        """
        raise NotImplementedError()

    @abstractmethod
    def sample(self, num_points: int, key: KeyArray) -> Float[Array, "N D"]:
        """Draw points from the search space.

        Args:
            num_points (int): How many points to sample from the search space.
            key (KeyArray): JAX PRNG key.

        Returns:
            Float[Array, "N D"]: `num_points` points sampled from the search space.
        """
        raise NotImplementedError()


@dataclass
class ContinuousSearchSpace(AbstractSearchSpace):
    """Box-shaped domain for continuous real functions of dimension $D$.

    Attributes:
        lower_bounds (Float[Array, " D"]): Lower bound of every dimension.
        upper_bounds (Float[Array, " D"]): Upper bound of every dimension.

    Raises:
        ValueError: If the bounds differ in dtype or shape, are not
            one-dimensional, are empty, or if any lower bound exceeds the
            corresponding upper bound.
    """

    lower_bounds: Float[Array, " D"]
    upper_bounds: Float[Array, " D"]

    def __post_init__(self):
        """Validate that the two bound arrays describe a non-empty box."""
        if self.lower_bounds.dtype != self.upper_bounds.dtype:
            raise ValueError("Lower and upper bounds must have the same dtype.")
        if self.lower_bounds.shape != self.upper_bounds.shape:
            raise ValueError("Lower and upper bounds must have the same shape.")
        # `shape[0]` below is only meaningful for vectors: a 0-d array would
        # raise an IndexError and a (1, D) array would be treated as a
        # one-dimensional space whose samples silently broadcast to D columns.
        if self.lower_bounds.ndim != 1:
            raise ValueError("Lower and upper bounds must be one-dimensional.")
        if self.lower_bounds.shape[0] == 0:
            raise ValueError("Lower and upper bounds cannot be empty")
        # Equal bounds are accepted (a dimension may be pinned to one value).
        # Written as `not (<=).all()` so that NaN entries are rejected too.
        if not (self.lower_bounds <= self.upper_bounds).all():
            raise ValueError(
                "Lower bounds must be less than or equal to upper bounds."
            )

    @property
    def dimensionality(self) -> int:
        """Dimensionality of the search space.

        Returns:
            int: Number of entries in the bound vectors.
        """
        return self.lower_bounds.shape[0]

    def sample(self, num_points: int, key: KeyArray) -> Float[Array, "N D"]:
        """Sample points from the search space using a Halton sequence.

        Args:
            num_points (int): Number of points to be sampled from the search space.
            key (KeyArray): JAX PRNG key, passed as the seed of the Halton sampler.

        Returns:
            Float[Array, "N D"]: `num_points` points sampled using the Halton
                sequence from the search space.

        Raises:
            ValueError: If `num_points` is not strictly positive.
        """
        if num_points <= 0:
            raise ValueError("Number of points must be greater than 0.")

        initial_sample = tfp.mcmc.sample_halton_sequence(
            dim=self.dimensionality, num_results=num_points, seed=key
        )
        # Affinely rescale the Halton draw onto the box defined by the bounds.
        return (
            self.lower_bounds + (self.upper_bounds - self.lower_bounds) * initial_sample
        )
5 changes: 4 additions & 1 deletion gpjax/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
DiagonalKernelComputation,
EigenKernelComputation,
)
from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
from gpjax.kernels.non_euclidean import (
CatKernel,
GraphKernel,
)
from gpjax.kernels.nonstationary import (
ArcCosine,
Linear,
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/non_euclidean/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from gpjax.kernels.non_euclidean.graph import GraphKernel
from gpjax.kernels.non_euclidean.categorical import CatKernel
from gpjax.kernels.non_euclidean.graph import GraphKernel

__all__ = ["GraphKernel", "CatKernel"]
12 changes: 9 additions & 3 deletions gpjax/kernels/non_euclidean/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@


from dataclasses import dataclass
from typing import NamedTuple, Union
from typing import (
NamedTuple,
Union,
)

import jax.numpy as jnp
from jaxtyping import Float, Int
from jaxtyping import (
Float,
Int,
)
import tensorflow_probability.substrates.jax as tfp

from gpjax.base import (
param_field,
static_field,
)
from gpjax.kernels.base import AbstractKernel

from gpjax.typing import (
Array,
ScalarInt,
Expand Down
Empty file.
218 changes: 218 additions & 0 deletions tests/test_bayes_opt/test_search_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright 2023 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from jax.config import config
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Array,
Float,
)
import pytest

from gpjax.bayes_opt.search_space import (
AbstractSearchSpace,
ContinuousSearchSpace,
)

config.update("jax_enable_x64", True)


def test_abstract_search_space():
    """The abstract base class cannot be instantiated directly."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        AbstractSearchSpace()


def test_continuous_search_space_empty_bounds():
    """Zero-length bounds are rejected with a ValueError."""
    no_bounds = jnp.array([])
    with pytest.raises(ValueError):
        ContinuousSearchSpace(lower_bounds=no_bounds, upper_bounds=no_bounds)


@pytest.mark.parametrize(
    "lower_bounds, upper_bounds",
    [
        (jnp.array([0.0], dtype=lower_dtype), jnp.array([1.0], dtype=upper_dtype))
        for lower_dtype, upper_dtype in (
            (jnp.float64, jnp.float32),
            (jnp.float32, jnp.float64),
        )
    ],
)
def test_continuous_search_space_dtype_consistency(
    lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"]
):
    """Bounds whose dtypes differ are rejected with a ValueError."""
    with pytest.raises(ValueError):
        ContinuousSearchSpace(lower_bounds, upper_bounds)


@pytest.mark.parametrize(
    "lower_bounds, upper_bounds",
    [
        (jnp.array([0.0] * num_lower), jnp.array([1.0] * num_upper))
        for num_lower, num_upper in ((1, 2), (2, 1))
    ],
)
def test_continous_search_space_bounds_shape_consistency(
    lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"]
):
    """Bounds of differing lengths are rejected with a ValueError."""
    with pytest.raises(ValueError):
        ContinuousSearchSpace(lower_bounds, upper_bounds)


@pytest.mark.parametrize(
    "lower_bounds, upper_bounds",
    [
        (jnp.array(lower), jnp.array(upper))
        for lower, upper in (
            ([1.0], [0.0]),
            ([1.0, 1.0], [0.0, 2.0]),
            ([1.0, 1.0], [2.0, 0.0]),
        )
    ],
)
def test_continuous_search_space_bounds_values_consistency(
    lower_bounds: Float[Array, " D"], upper_bounds: Float[Array, " D"]
):
    """A lower bound above its upper bound in any dimension is rejected."""
    with pytest.raises(ValueError):
        ContinuousSearchSpace(lower_bounds, upper_bounds)


@pytest.mark.parametrize(
    "continuous_search_space, dimensionality",
    [
        (
            ContinuousSearchSpace(
                jnp.array([0.0] * num_dims), jnp.array([1.0] * num_dims)
            ),
            num_dims,
        )
        for num_dims in (1, 2, 3)
    ],
)
def test_continuous_search_space_dimensionality(
    continuous_search_space: ContinuousSearchSpace, dimensionality: int
):
    """`dimensionality` equals the length of the bound vectors."""
    reported = continuous_search_space.dimensionality
    assert reported == dimensionality


@pytest.mark.parametrize(
    "continuous_search_space",
    [
        ContinuousSearchSpace(jnp.array([0.0] * num_dims), jnp.array([1.0] * num_dims))
        for num_dims in (1, 2, 3)
    ],
)
@pytest.mark.parametrize("num_points", [0, -1])
def test_continous_search_space_invalid_sample_num_points(
    continuous_search_space: ContinuousSearchSpace, num_points: int
):
    """Requesting a non-positive number of samples raises a ValueError."""
    prng_key = jr.PRNGKey(42)
    with pytest.raises(ValueError):
        continuous_search_space.sample(num_points=num_points, key=prng_key)


@pytest.mark.parametrize(
    "continuous_search_space, dimensionality",
    [
        (
            ContinuousSearchSpace(
                jnp.array([0.0] * num_dims), jnp.array([1.0] * num_dims)
            ),
            num_dims,
        )
        for num_dims in (1, 2, 3)
    ],
)
@pytest.mark.parametrize("num_points", [1, 5, 50])
# Sampling with tfp makes JAX emit a UserWarning (internal logic around
# jnp.argsort), so it is silenced for this test.
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_continuous_search_space_sample_shape(
    continuous_search_space: ContinuousSearchSpace, dimensionality: int, num_points: int
):
    """Samples have `num_points` rows and `dimensionality` columns."""
    prng_key = jr.PRNGKey(42)
    drawn = continuous_search_space.sample(num_points=num_points, key=prng_key)
    num_rows, num_cols = drawn.shape[0], drawn.shape[1]
    assert num_rows == num_points
    assert num_cols == dimensionality


@pytest.mark.parametrize(
    "continuous_search_space",
    [
        ContinuousSearchSpace(jnp.array([0.0] * num_dims), jnp.array([1.0] * num_dims))
        for num_dims in (1, 2, 3)
    ],
)
@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(5)])
# Sampling with tfp makes JAX emit a UserWarning (internal logic around
# jnp.argsort), so it is silenced for this test.
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_continous_search_space_sample_same_key_same_samples(
    continuous_search_space: ContinuousSearchSpace, key: jr.PRNGKey
):
    """Sampling twice with the same key yields identical samples."""
    first_draw, second_draw = (
        continuous_search_space.sample(num_points=100, key=key) for _ in range(2)
    )
    assert jnp.array_equal(first_draw, second_draw)


@pytest.mark.parametrize(
    "continuous_search_space",
    [
        ContinuousSearchSpace(jnp.array([0.0] * num_dims), jnp.array([1.0] * num_dims))
        for num_dims in (1, 2, 3)
    ],
)
@pytest.mark.parametrize(
    "key_one, key_two",
    [(jr.PRNGKey(42), jr.PRNGKey(5)), (jr.PRNGKey(1), jr.PRNGKey(2))],
)
# Sampling with tfp makes JAX emit a UserWarning (internal logic around
# jnp.argsort), so it is silenced for this test.
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_continuous_search_space_different_keys_different_samples(
    continuous_search_space: ContinuousSearchSpace,
    key_one: jr.PRNGKey,
    key_two: jr.PRNGKey,
):
    """Sampling with two different keys yields different samples."""
    draws = [
        continuous_search_space.sample(num_points=100, key=prng_key)
        for prng_key in (key_one, key_two)
    ]
    assert not jnp.array_equal(draws[0], draws[1])


@pytest.mark.parametrize(
    "continuous_search_space",
    [
        ContinuousSearchSpace(
            lower_bounds=jnp.array(lower), upper_bounds=jnp.array(upper)
        )
        for lower, upper in (
            ([0.0], [1.0]),
            ([0.0, 0.0], [1.0, 2.0]),
            ([0.0, 1.0], [2.0, 2.0]),
            ([2.4, 1.7, 4.9], [5.6, 1.8, 6.0]),
        )
    ],
)
# Sampling with tfp makes JAX emit a UserWarning (internal logic around
# jnp.argsort), so it is silenced for this test.
@pytest.mark.filterwarnings("ignore::UserWarning")
def test_continuous_search_space_valid_sample_ranges(
    continuous_search_space: ContinuousSearchSpace,
):
    """Every sampled coordinate lies within the bounds of its dimension."""
    drawn = continuous_search_space.sample(num_points=100, key=jr.PRNGKey(42))
    # The (D,) bound vectors broadcast against the (100, D) samples, so each
    # column is compared with the bound of its own dimension.
    assert jnp.all(drawn >= continuous_search_space.lower_bounds)
    assert jnp.all(drawn <= continuous_search_space.upper_bounds)
4 changes: 1 addition & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,7 @@ def test_precision_warning(
if prec_y != jnp.float64:
expected_warnings += 1

with pytest.warns(
UserWarning, match=".* is not of type float64.*"
) as record:
with pytest.warns(UserWarning, match=".* is not of type float64.*") as record:
Dataset(X=x, y=y)

assert len(record) == expected_warnings
Loading