From 0bc0435a0daecf8bea396ecef99cf42c90b082f4 Mon Sep 17 00:00:00 2001 From: Stephen Huan Date: Thu, 27 Jun 2024 03:14:00 -0400 Subject: [PATCH] fix(gpjax/thompson_sampling): add NonConjugatePosterior type annotation --- .../decision_making/utility_functions/thompson_sampling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/gpjax/decision_making/utility_functions/thompson_sampling.py b/gpjax/decision_making/utility_functions/thompson_sampling.py index 7367e978..965e4a1a 100644 --- a/gpjax/decision_making/utility_functions/thompson_sampling.py +++ b/gpjax/decision_making/utility_functions/thompson_sampling.py @@ -22,7 +22,10 @@ SinglePointUtilityFunction, ) from gpjax.decision_making.utils import OBJECTIVE -from gpjax.gps import ConjugatePosterior +from gpjax.gps import ( + ConjugatePosterior, + NonConjugatePosterior, +) from gpjax.typing import KeyArray @@ -56,7 +59,7 @@ def __post_init__(self): def build_utility_function( self, - posteriors: Mapping[str, ConjugatePosterior], + posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior], datasets: Mapping[str, Dataset], key: KeyArray, ) -> SinglePointUtilityFunction: