Skip to content

Commit

Permalink
Merge pull request #285 from JaxGaussianProcesses/static-typing-fixes
Browse files Browse the repository at this point in the history
Static typing fixes
  • Loading branch information
thomaspinder authored Jun 1, 2023
2 parents 558e797 + 5db6e20 commit 6138af4
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 7 deletions.
9 changes: 6 additions & 3 deletions gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Callable,
Optional,
Tuple,
TypeVar,
Union,
)
import jax
Expand All @@ -36,11 +37,13 @@
ScalarFloat,
)

ModuleModel = TypeVar("ModuleModel", bound=Module)


def fit( # noqa: PLR0913
*,
model: Module,
objective: Union[AbstractObjective, Callable[[Module, Dataset], ScalarFloat]],
model: ModuleModel,
objective: Union[AbstractObjective, Callable[[ModuleModel, Dataset], ScalarFloat]],
train_data: Dataset,
optim: ox.GradientTransformation,
key: KeyArray,
Expand All @@ -50,7 +53,7 @@ def fit( # noqa: PLR0913
verbose: Optional[bool] = True,
unroll: Optional[int] = 1,
safe: Optional[bool] = True,
) -> Tuple[Module, Array]:
) -> Tuple[ModuleModel, Array]:
r"""Train a Module model with respect to a supplied Objective function.
Optimisers used here should originate from Optax.
Expand Down
50 changes: 48 additions & 2 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
# ==============================================================================

# from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing import overload

from beartype.typing import (
Any,
Expand Down Expand Up @@ -43,6 +45,7 @@
from gpjax.likelihoods import (
AbstractLikelihood,
Gaussian,
NonGaussianLikelihood,
)
from gpjax.linops import identity
from gpjax.mean_functions import AbstractMeanFunction
Expand Down Expand Up @@ -134,7 +137,19 @@ class Prior(AbstractPrior):
```
"""

def __mul__(self, other: AbstractLikelihood):
@overload
def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
...

@overload
def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...

@overload
def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...

def __mul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
The product of a prior and likelihood is proportional to the posterior
Expand Down Expand Up @@ -168,7 +183,19 @@ def __mul__(self, other: AbstractLikelihood):
"""
return construct_posterior(prior=self, likelihood=other)

def __rmul__(self, other: AbstractLikelihood):
@overload
def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
...

@overload
def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...

@overload
def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...

def __rmul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Reimplement the multiplication operator to allow for order-invariant
Expand Down Expand Up @@ -654,9 +681,28 @@ def predict(
#######################
# Utils
#######################


@overload
def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior:
...


@overload
def construct_posterior(
prior: Prior, likelihood: NonGaussianLikelihood
) -> NonConjugatePosterior:
...


@overload
def construct_posterior(
prior: Prior, likelihood: AbstractLikelihood
) -> AbstractPosterior:
...


def construct_posterior(prior, likelihood):
r"""Utility function for constructing a posterior object from a prior and
likelihood. The function will automatically select the correct posterior
object based on the likelihood.
Expand Down
3 changes: 3 additions & 0 deletions gpjax/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter


NonGaussianLikelihood = Union[Poisson, Bernoulli]

__all__ = [
"AbstractLikelihood",
"NonGaussianLikelihood",
"Gaussian",
"Bernoulli",
"Poisson",
Expand Down
7 changes: 5 additions & 2 deletions tests/test_gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
ValidationErrors = ValueError

from dataclasses import is_dataclass
from typing import Callable
from typing import (
Callable,
Type,
)

from jax.config import config
import jax.numpy as jnp
Expand Down Expand Up @@ -214,7 +217,7 @@ def test_nonconjugate_posterior(
@pytest.mark.parametrize("kernel", [RBF(), Matern52()])
@pytest.mark.parametrize("mean_function", [Zero(), Constant()])
def test_posterior_construct(
likelihood: AbstractLikelihood,
likelihood: Type[AbstractLikelihood],
num_datapoints: int,
mean_function: AbstractMeanFunction,
kernel: AbstractKernel,
Expand Down

0 comments on commit 6138af4

Please sign in to comment.