diff --git a/gpjax/decision_making/utility_functions/thompson_sampling.py b/gpjax/decision_making/utility_functions/thompson_sampling.py index 7367e978..965e4a1a 100644 --- a/gpjax/decision_making/utility_functions/thompson_sampling.py +++ b/gpjax/decision_making/utility_functions/thompson_sampling.py @@ -22,7 +22,10 @@ SinglePointUtilityFunction, ) from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, +) from gpjax.typing import KeyArray @@ -56,7 +59,7 @@ def __post_init__(self): def build_utility_function( self, - posteriors: Mapping[str, ConjugatePosterior], + posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, ) -> SinglePointUtilityFunction: