Skip to content

Commit

Permalink
Start adding some check_shape decorations
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 18, 2023
1 parent 9e25370 commit dc786e1
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,21 @@

from __future__ import annotations

from typing import Iterable

import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from check_shapes import (
DocstringFormat,
ShapeCheckingState,
get_enable_check_shapes,
get_enable_function_call_precompute,
get_rewrite_docstrings,
set_enable_check_shapes,
set_enable_function_call_precompute,
set_rewrite_docstrings,
)


def pytest_addoption(parser: Parser) -> None:
Expand Down Expand Up @@ -74,3 +86,18 @@ def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]) -> N
import tensorflow as tf

tf.config.experimental_run_functions_eagerly(True)


@pytest.fixture(autouse=True)
def enable_shape_checks() -> Iterable[None]:
# ensure `check_shapes` is always enabled for tests
old_enable = get_enable_check_shapes()
old_rewrite_docstrings = get_rewrite_docstrings()
old_function_call_precompute = get_enable_function_call_precompute()
set_enable_check_shapes(ShapeCheckingState.ENABLED)
set_rewrite_docstrings(DocstringFormat.SPHINX)
set_enable_function_call_precompute(True)
yield
set_enable_function_call_precompute(old_function_call_precompute)
set_rewrite_docstrings(old_rewrite_docstrings)
set_enable_check_shapes(old_enable)
5 changes: 3 additions & 2 deletions tests/unit/acquisition/function/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes.exceptions import ShapeMismatchError

from tests.util.misc import (
TF_DEBUGGING_ERROR_TYPES,
Expand Down Expand Up @@ -282,7 +283,7 @@ def test_expected_improvement_switches_to_improvement_on_feasible_points() -> No
def test_expected_improvement_raises_for_invalid_batch_size(at: TensorType) -> None:
ei = expected_improvement(QuadraticMeanAndRBFKernel(), tf.constant([1.0]))

with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(ShapeMismatchError):
ei(at)


Expand Down Expand Up @@ -946,7 +947,7 @@ def test_expected_constrained_improvement_raises_for_invalid_batch_size(at: Tens

eci = builder.prepare_acquisition_function({NA: QuadraticMeanAndRBFKernel()}, datasets=data)

with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(ShapeMismatchError):
eci(at)


Expand Down
9 changes: 5 additions & 4 deletions trieste/acquisition/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import check_shapes

from ...data import Dataset
from ...models import ProbabilisticModel, ReparametrizationSampler
Expand Down Expand Up @@ -211,12 +212,12 @@ def update(self, eta: TensorType) -> None:
"""Update the acquisition function with a new eta value."""
self._eta.assign(eta)

@check_shapes(
"x: [N..., 1, D] # This acquisition function only supports batch sizes of one",
"return: [N..., L]",
)
@tf.function
def __call__(self, x: TensorType) -> TensorType:
tf.debugging.assert_shapes(
[(x, [..., 1, None])],
message="This acquisition function only supports batch sizes of one.",
)
mean, variance = self._model.predict(tf.squeeze(x, -2))
normal = tfp.distributions.Normal(mean, tf.sqrt(variance))
return (self._eta - mean) * normal.cdf(self._eta) + variance * normal.prob(self._eta)
Expand Down

0 comments on commit dc786e1

Please sign in to comment.