Skip to content

Commit

Permalink
this was unbearable 🐾
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Oct 27, 2023
1 parent 831299c commit 999f08c
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 70 deletions.
35 changes: 15 additions & 20 deletions gpjax/citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
NonConjugateMLL,
)

CitationType = Union[str, Dict[str, str]]
CitationType = Union[None, str, Dict[str, str]]


@dataclass(repr=False)
class AbstractCitation:
citation_key: str = None
authors: str = None
title: str = None
year: str = None
citation_key: Union[str, None] = None
authors: Union[str, None] = None
title: Union[str, None] = None
year: Union[str, None] = None

def as_str(self) -> str:
citation_str = f"@{self.citation_type}{{{self.citation_key},"
Expand All @@ -64,29 +64,24 @@ def __str__(self) -> str:
)


class JittedFnCitation(AbstractCitation):
def __str__(self) -> str:
return "Citation not available for jitted objects."


@dataclass
class PhDThesisCitation(AbstractCitation):
school: str = None
institution: str = None
citation_type: str = "phdthesis"
school: Union[str, None] = None
institution: Union[str, None] = None
citation_type: CitationType = "phdthesis"


@dataclass
class PaperCitation(AbstractCitation):
booktitle: str = None
citation_type: str = "inproceedings"
booktitle: Union[str, None] = None
citation_type: CitationType = "inproceedings"


@dataclass
class BookCitation(AbstractCitation):
publisher: str = None
volume: str = None
citation_type: str = "book"
publisher: Union[str, None] = None
volume: Union[str, None] = None
citation_type: CitationType = "book"


####################
Expand All @@ -101,8 +96,8 @@ def cite(tree) -> AbstractCitation:
# Default citation
####################
@cite.register(PjitFunction)
def _(tree):
return JittedFnCitation()
def _(tree) -> None:
raise RuntimeError("Citation not available for jitted objects.")


####################
Expand Down
40 changes: 20 additions & 20 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

# from __future__ import annotations
from abc import abstractmethod
from dataclasses import (
dataclass,
Expand All @@ -25,6 +24,7 @@
Any,
Callable,
Optional,
Union,
)
import cola
from cola.ops import Dense
Expand Down Expand Up @@ -152,17 +152,17 @@ class Prior(AbstractPrior):
```
"""

@overload
def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
...
# @overload
# def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...

@overload
def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...
# @overload
# def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...

@overload
def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...
# @overload
# def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...

def __mul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -198,17 +198,17 @@ def __mul__(self, other):
"""
return construct_posterior(prior=self, likelihood=other)

@overload
def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
...
# @overload
# def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...

@overload
def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...
# @overload
# def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...

@overload
def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...
# @overload
# def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...

def __rmul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -656,7 +656,7 @@ class NonConjugatePosterior(AbstractPosterior):
from, or optimise an approximation to, the posterior distribution.
"""

latent: Float[Array, "N 1"] = param_field(None)
latent: Union[Float[Array, "N 1"], None] = param_field(None)
key: KeyArray = static_field(PRNGKey(42))

def __post_init__(self):
Expand Down
21 changes: 14 additions & 7 deletions gpjax/integrators.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import (
TypeVar,
Union,
)

from beartype.typing import Callable
import jax.numpy as jnp
from jaxtyping import Float
import numpy as np

import gpjax
from gpjax.typing import Array

if TYPE_CHECKING:
import gpjax.likelihoods
Likelihood = TypeVar(
"Likelihood",
bound=Union["gpjax.likelihoods.AbstractLikelihood", None], # noqa: F821
)
Gaussian = TypeVar("Gaussian", bound="gpjax.likelihoods.Gaussian") # noqa: F821


@dataclass
Expand All @@ -24,7 +31,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Integrate a function with respect to a Gaussian distribution.
Expand All @@ -47,7 +54,7 @@ def __call__(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Integrate a function with respect to a Gaussian distribution.
Expand Down Expand Up @@ -86,7 +93,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Compute a quadrature integral.
Expand Down Expand Up @@ -127,7 +134,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.Gaussian" = None,
likelihood: Gaussian,
) -> Float[Array, " N"]:
r"""Compute a Gaussian integral.
Expand Down
10 changes: 7 additions & 3 deletions gpjax/kernels/approximations/rff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Compute Random Fourier Feature (RFF) kernel approximations. """
from dataclasses import dataclass

from beartype.typing import Union
from jax.random import PRNGKey
from jaxtyping import Float
import tensorflow_probability.substrates.jax.bijectors as tfb
Expand Down Expand Up @@ -34,9 +35,11 @@ class RFF(AbstractKernel):
- 'On the Error of Random Fourier Features' by Sutherland and Schneider (2015).
"""

base_kernel: AbstractKernel = None
base_kernel: Union[AbstractKernel, None] = None
num_basis_fns: int = static_field(50)
frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity())
frequencies: Union[Float[Array, "M D"], None] = param_field(
None, bijector=tfb.Identity()
)
compute_engine: BasisFunctionComputation = static_field(
BasisFunctionComputation(), repr=False
)
Expand All @@ -57,8 +60,9 @@ def __post_init__(self) -> None:
)
self.name = f"{self.base_kernel.name} (RFF)"

def __call__(self, x: Array, y: Array) -> Array:
def __call__(self, x: Float[Array, "D 1"], y: Float[Array, "D 1"]) -> None:
"""Superfluous for RFFs."""
raise RuntimeError("RFFs do not have a kernel function.")

def _check_valid_base_kernel(self, kernel: AbstractKernel):
r"""Verify that the base kernel is valid for RFF approximation.
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/non_euclidean/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class CatKernel(AbstractKernel):
cholesky_lower: Float[Array, "N N"] = param_field(
jnp.eye(2), bijector=tfb.CorrelationCholesky()
)
inspace_vals: list = static_field(None)
inspace_vals: Union[list, None] = static_field(None)
name: str = "Categorical Kernel"
input_1hot: bool = static_field(False)

Expand Down
9 changes: 5 additions & 4 deletions gpjax/kernels/non_euclidean/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from dataclasses import dataclass

from beartype.typing import Union
import jax.numpy as jnp
from jaxtyping import (
Float,
Expand Down Expand Up @@ -58,13 +59,13 @@ class GraphKernel(AbstractKernel):
of a graph.
"""

laplacian: Num[Array, "N N"] = static_field(None)
laplacian: Union[Num[Array, "N N"], None] = static_field(None)
lengthscale: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
smoothness: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus())
eigenvalues: Float[Array, " N"] = static_field(None)
eigenvectors: Float[Array, "N N"] = static_field(None)
num_vertex: ScalarInt = static_field(None)
eigenvalues: Union[Float[Array, "N 1"], None] = static_field(None)
eigenvectors: Union[Float[Array, "N N"], None] = static_field(None)
num_vertex: Union[ScalarInt, None] = static_field(None)
compute_engine: AbstractKernelComputation = static_field(
EigenKernelComputation(), repr=False
)
Expand Down
29 changes: 20 additions & 9 deletions gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,28 @@

tfd = tfp.distributions

from typing import TypeVar

import cola

ConjugatePosterior = TypeVar(
"ConjugatePosterior", bound="gpjax.gps.ConjugatePosterior" # noqa: F821
)
NonConjugatePosterior = TypeVar(
"NonConjugatePosterior", bound="gpjax.gps.NonConjugatePosterior" # noqa: F821
)
VariationalFamily = TypeVar(
"VariationalFamily",
bound="gpjax.variational_families.AbstractVariationalFamily", # noqa: F821
)


@dataclass
class AbstractObjective(Module):
r"""Abstract base class for objectives."""

negative: bool = static_field(False)
constant: float = static_field(init=False, repr=False)
constant: ScalarFloat = static_field(init=False, repr=False)

def __post_init__(self) -> None:
self.constant = jnp.array(-1.0) if self.negative else jnp.array(1.0)
Expand All @@ -49,8 +62,8 @@ def step(self, *args, **kwargs) -> ScalarFloat:
class ConjugateMLL(AbstractObjective):
def step(
self,
posterior: "gpjax.gps.ConjugatePosterior", # noqa: F821
train_data: Dataset, # noqa: F821
posterior: ConjugatePosterior,
train_data: Dataset,
) -> ScalarFloat:
r"""Evaluate the marginal log-likelihood of the Gaussian process.
Expand Down Expand Up @@ -148,9 +161,7 @@ class LogPosteriorDensity(AbstractObjective):
sometimes referred to as the marginal log-likelihood.
"""

def step(
self, posterior: "gpjax.gps.NonConjugatePosterior", data: Dataset # noqa: F821
) -> ScalarFloat:
def step(self, posterior: NonConjugatePosterior, data: Dataset) -> ScalarFloat:
r"""Evaluate the log-posterior density of a Gaussian process.
Compute the marginal log-likelihood, or log-posterior density of the Gaussian
Expand Down Expand Up @@ -210,7 +221,7 @@ def step(
class ELBO(AbstractObjective):
def step(
self,
variational_family: "gpjax.variational_families.AbstractVariationalFamily", # noqa: F821
variational_family: VariationalFamily,
train_data: Dataset,
) -> ScalarFloat:
r"""Compute the evidence lower bound of a variational approximation.
Expand Down Expand Up @@ -249,7 +260,7 @@ def step(


def variational_expectation(
variational_family: "gpjax.variational_families.AbstractVariationalFamily", # noqa: F821
variational_family: VariationalFamily,
train_data: Dataset,
) -> Float[Array, " N"]:
r"""Compute the variational expectation.
Expand Down Expand Up @@ -304,7 +315,7 @@ class CollapsedELBO(AbstractObjective):

def step(
self,
variational_family: "gpjax.variational_families.AbstractVariationalFamily", # noqa: F821
variational_family: VariationalFamily,
train_data: Dataset,
) -> ScalarFloat:
r"""Compute a single step of the collapsed evidence lower bound.
Expand Down
7 changes: 5 additions & 2 deletions gpjax/variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import abc
from dataclasses import dataclass

from beartype.typing import Any
from beartype.typing import (
Any,
Union,
)
import cola
import jax.numpy as jnp
import jax.scipy as jsp
Expand Down Expand Up @@ -106,7 +109,7 @@ class VariationalGaussian(AbstractVariationalGaussian):
$`\mu`$ and $`sqrt`$ with $`S = sqrt sqrt^{\top}`$.
"""

variational_mean: Float[Array, "N 1"] = param_field(None)
variational_mean: Union[Float[Array, "N 1"], None] = param_field(None)
variational_root_covariance: Float[Array, "N N"] = param_field(
None, bijector=tfb.FillTriangular()
)
Expand Down
6 changes: 2 additions & 4 deletions tests/test_citations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from gpjax.citation import (
AbstractCitation,
BookCitation,
JittedFnCitation,
NullCitation,
PaperCitation,
PhDThesisCitation,
Expand Down Expand Up @@ -209,6 +208,5 @@ def test_logarithmic_goldstein_price():
[gpx.ELBO(), gpx.CollapsedELBO(), gpx.LogPosteriorDensity(), gpx.ConjugateMLL()],
)
def test_jitted_fallback(objective):
citation = cite(jit(objective))
assert isinstance(citation, JittedFnCitation)
assert citation.__str__() == "Citation not available for jitted objects."
with pytest.raises(RuntimeError):
_ = cite(jit(objective))
Loading

0 comments on commit 999f08c

Please sign in to comment.