diff --git a/gpjax/base/module.py b/gpjax/base/module.py index 3d381d3c..3af77825 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -87,7 +87,7 @@ def static_field( # noqa: PLR0913 ) -def _inherited_metadata(cls: type) -> Dict[str]: +def _inherited_metadata(cls: type) -> Dict[str, Any]: meta_data = dict() for parent_class in cls.mro(): if parent_class is not cls and parent_class is not Module: diff --git a/gpjax/decision_making/utility_functions/thompson_sampling.py b/gpjax/decision_making/utility_functions/thompson_sampling.py index 7367e978..965e4a1a 100644 --- a/gpjax/decision_making/utility_functions/thompson_sampling.py +++ b/gpjax/decision_making/utility_functions/thompson_sampling.py @@ -22,7 +22,10 @@ SinglePointUtilityFunction, ) from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, +) from gpjax.typing import KeyArray @@ -56,7 +59,7 @@ def __post_init__(self): def build_utility_function( self, - posteriors: Mapping[str, ConjugatePosterior], + posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, ) -> SinglePointUtilityFunction: