diff --git a/gpjax/fit.py b/gpjax/fit.py index 6b272829..f4693736 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -131,7 +131,8 @@ def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float: # Initialise solver state. solver.fun = _wrap_objective(solver.fun) - solver.options.pop("maxiter", None) # allow __post_init__ without jaxopt error + if hasattr(solver, "options"): # allow __post_init__ without weird jaxopt error + solver.options.pop("maxiter", None) solver.__post_init__() # needed to propagate changes to `fun` attribute if isinstance(solver, OptaxSolver): # hack for Optax compatibility