From 4fae37ecf0c529806794f727237d8ac3b54d2882 Mon Sep 17 00:00:00 2001 From: hmoss <32096840+henrymoss@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:23:15 +0100 Subject: [PATCH] extra test --- tests/test_fit.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_fit.py b/tests/test_fit.py index ade942a4..6ca577f2 100644 --- a/tests/test_fit.py +++ b/tests/test_fit.py @@ -97,6 +97,32 @@ def step(self, model: LinearModel, train_data: Dataset) -> float: assert trained_model.bias == 1.0 +@pytest.mark.parametrize("batch_size", [10, 100]) +def test_raises_if_try_to_batch_scipy_optim(batch_size: int) -> None: + # Create dataset: + key = jr.PRNGKey(123) + x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(10, 1)), axis=0) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + + # Define GP model: + prior = Prior(kernel=RBF(), mean_function=Constant()) + likelihood = Gaussian(num_datapoints=10) + posterior = prior * likelihood + + # Define loss function: + mll = ConjugateMLL(negative=True) + + with pytest.raises(ValueError): + fit( + model=posterior, + train_data=D, + solver=jaxopt.ScipyMinimize(fun=mll), + batch_size=batch_size, + key=jr.PRNGKey(123), + ) + + @pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("n_data", [1, 20]) @pytest.mark.parametrize("verbose", [True, False])