diff --git a/gpjax/decision_making/posterior_handler.py b/gpjax/decision_making/posterior_handler.py index 97b79f18..bc81ab51 100644 --- a/gpjax/decision_making/posterior_handler.py +++ b/gpjax/decision_making/posterior_handler.py @@ -144,7 +144,8 @@ def _optimize_posterior( # # We create a new solver state -> since the dataset (and therefore loss function) has changed! attributes = asdict(self.solver) - attributes["options"].pop("maxiter", None) # allow reinit without jaxopt error + if hasattr(attributes, "options"): # allow reinit without jaxopt error + attributes["options"].pop("maxiter", None) attributes.pop("fun", None) # pass in fun as callable rather than dict new_solver = self.solver.__class__(fun=self.solver.fun, **attributes)