Skip to content

Commit

Permalink
extra test
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss committed Sep 22, 2023
1 parent 57f0dc3 commit 4fae37e
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
assert trained_model.bias == 1.0


@pytest.mark.parametrize("batch_size", [10, 100])
def test_raises_if_try_to_batch_scipy_optim(batch_size: int) -> None:
# Create dataset:
key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(10, 1)), axis=0)
y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1
D = Dataset(X=x, y=y)

# Define GP model:
prior = Prior(kernel=RBF(), mean_function=Constant())
likelihood = Gaussian(num_datapoints=10)
posterior = prior * likelihood

# Define loss function:
mll = ConjugateMLL(negative=True)

with pytest.raises(ValueError):
fit(
model=posterior,
train_data=D,
solver=jaxopt.ScipyMinimize(fun=mll),
batch_size=batch_size,
key=jr.PRNGKey(123),
)


@pytest.mark.parametrize("num_iters", [1, 5])
@pytest.mark.parametrize("n_data", [1, 20])
@pytest.mark.parametrize("verbose", [True, False])
Expand Down

0 comments on commit 4fae37e

Please sign in to comment.