diff --git a/gpjax/fit.py b/gpjax/fit.py index 1aa48673..e5ad9d1d 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -19,6 +19,7 @@ Callable, Optional, Tuple, + TypeVar, Union, ) import jax @@ -36,11 +37,13 @@ ScalarFloat, ) +ModuleModel = TypeVar("ModuleModel", bound=Module) + def fit( # noqa: PLR0913 *, - model: Module, - objective: Union[AbstractObjective, Callable[[Module, Dataset], ScalarFloat]], + model: ModuleModel, + objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]], train_data: Dataset, optim: ox.GradientTransformation, key: KeyArray, @@ -50,7 +53,7 @@ def fit( # noqa: PLR0913 verbose: Optional[bool] = True, unroll: Optional[int] = 1, safe: Optional[bool] = True, -) -> Tuple[Module, Array]: +) -> Tuple[ModuleModel, Array]: r"""Train a Module model with respect to a supplied Objective function. Optimisers used here should originate from Optax. diff --git a/gpjax/gps.py b/gpjax/gps.py index d528737b..77d2e58a 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== +# from __future__ import annotations from abc import abstractmethod from dataclasses import dataclass +from typing import overload from beartype.typing import ( Any, @@ -43,6 +45,7 @@ from gpjax.likelihoods import ( AbstractLikelihood, Gaussian, + NonGaussianLikelihood, ) from gpjax.linops import identity from gpjax.mean_functions import AbstractMeanFunction @@ -134,7 +137,19 @@ class Prior(AbstractPrior): ``` """ - def __mul__(self, other: AbstractLikelihood): + @overload + def __mul__(self, other: Gaussian) -> "ConjugatePosterior": + ... + + @overload + def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior": + ... + + @overload + def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior": + ... + + def __mul__(self, other): r"""Combine the prior with a likelihood to form a posterior distribution. The product of a prior and likelihood is proportional to the posterior @@ -168,7 +183,19 @@ def __mul__(self, other: AbstractLikelihood): """ return construct_posterior(prior=self, likelihood=other) - def __rmul__(self, other: AbstractLikelihood): + @overload + def __rmul__(self, other: Gaussian) -> "ConjugatePosterior": + ... + + @overload + def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior": + ... + + @overload + def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior": + ... + + def __rmul__(self, other): r"""Combine the prior with a likelihood to form a posterior distribution. Reimplement the multiplication operator to allow for order-invariant @@ -654,9 +681,28 @@ def predict( ####################### # Utils ####################### + + +@overload +def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior: + ... + + +@overload +def construct_posterior( + prior: Prior, likelihood: NonGaussianLikelihood +) -> NonConjugatePosterior: + ... + + +@overload def construct_posterior( prior: Prior, likelihood: AbstractLikelihood ) -> AbstractPosterior: + ... + + +def construct_posterior(prior, likelihood): r"""Utility function for constructing a posterior object from a prior and likelihood. The function will automatically select the correct posterior object based on the likelihood. diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 70780c5c..2d503020 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -207,8 +207,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]: return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter +NonGaussianLikelihood = Union[Poisson, Bernoulli] + __all__ = [ "AbstractLikelihood", + "NonGaussianLikelihood", "Gaussian", "Bernoulli", "Poisson", diff --git a/tests/test_gps.py b/tests/test_gps.py index ba378b7d..a00a25fc 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -21,7 +21,10 @@ ValidationErrors = ValueError from dataclasses import is_dataclass -from typing import Callable +from typing import ( + Callable, + Type, +) from jax.config import config import jax.numpy as jnp @@ -214,7 +217,7 @@ def test_nonconjugate_posterior( @pytest.mark.parametrize("kernel", [RBF(), Matern52()]) @pytest.mark.parametrize("mean_function", [Zero(), Constant()]) def test_posterior_construct( - likelihood: AbstractLikelihood, + likelihood: Type[AbstractLikelihood], num_datapoints: int, mean_function: AbstractMeanFunction, kernel: AbstractKernel,