From 503d0b6f66401ee34a8cb318e279fb0f858d726a Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Thu, 27 Jun 2024 03:12:10 -0400 Subject: [PATCH 1/2] fix(gpjax/base/module.py): add type of value --- gpjax/base/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpjax/base/module.py b/gpjax/base/module.py index 3d381d3c..3af77825 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -87,7 +87,7 @@ def static_field( # noqa: PLR0913 ) -def _inherited_metadata(cls: type) -> Dict[str]: +def _inherited_metadata(cls: type) -> Dict[str, Any]: meta_data = dict() for parent_class in cls.mro(): if parent_class is not cls and parent_class is not Module: From 0bc0435a0daecf8bea396ecef99cf42c90b082f4 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Thu, 27 Jun 2024 03:14:00 -0400 Subject: [PATCH 2/2] fix(gpjax/thompson_sampling): add NonConjugatePosterior type annotation --- .../decision_making/utility_functions/thompson_sampling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gpjax/decision_making/utility_functions/thompson_sampling.py b/gpjax/decision_making/utility_functions/thompson_sampling.py index 7367e978..965e4a1a 100644 --- a/gpjax/decision_making/utility_functions/thompson_sampling.py +++ b/gpjax/decision_making/utility_functions/thompson_sampling.py @@ -22,7 +22,10 @@ SinglePointUtilityFunction, ) from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, +) from gpjax.typing import KeyArray @@ -56,7 +59,7 @@ def __post_init__(self): def build_utility_function( self, - posteriors: Mapping[str, ConjugatePosterior], + posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, ) -> SinglePointUtilityFunction: