Introduce beartype & fix types #230

Merged · 91 commits · Apr 27, 2023
Changes shown are from 69 of the 91 commits.

Commits
15b808f  add beartype dependency (st--, Apr 13, 2023)
0eaec30  from typing import -> from beartype.typing import (st--, Apr 13, 2023)
781bc38  jaxtyping import_hook for @jaxtyped @beartype everywhere (st--, Apr 13, 2023)
67e315f  fix Type[] of class-as-argument (st--, Apr 13, 2023)
e77732b  fix KeyArray type hint (should probably move into jaxutils though) (st--, Apr 13, 2023)
5175fa7  fix return value of slice_input when active_dims is None (st--, Apr 13, 2023)
c548d73  fix return value of squared_distance (st--, Apr 13, 2023)
079936f  fix return type of recursive_bijectors (st--, Apr 13, 2023)
1005212  fix slice_input type annotations (st--, Apr 13, 2023)
0c57a4e  new KernelCallable type to fix kernel_fn annotations (st--, Apr 13, 2023)
777b042  fix kernel __call__ annotation (st--, Apr 13, 2023)
df88c69  fix KeyArray type hint (st--, Apr 13, 2023)
711fcac  beartype does not like forward references; replaced with string types (st--, Apr 13, 2023)
1c95fa5  linops other type hint fixes (st--, Apr 13, 2023)
1b915cf  fix KeyArray (st--, Apr 13, 2023)
c3e48c1  abstractions.py some type fixes (st--, Apr 13, 2023)
4cc9db1  fix GaussianDistribution.log_prob return type (st--, Apr 13, 2023)
4c4a572  fix depreciations & warnings (st--, Apr 13, 2023)
c512195  Merge branch 'st/fix_depreciations' into st/beartype (st--, Apr 13, 2023)
9516fdf  fix scalar array types (st--, Apr 13, 2023)
2d45258  introduce ScalarBool, ScalarInt for jitted calls in abstractions (st--, Apr 13, 2023)
ecfde64  relax LinearOperator's solve() types (can be both matrix or vector), … (st--, Apr 13, 2023)
066a5d7  remove _stop_grad type hints, not sure what they should be (st--, Apr 13, 2023)
ff68725  found some more (st--, Apr 13, 2023)
4716773  Merge branch 'st/fix_depreciations' into st/beartype (st--, Apr 13, 2023)
5c90755  float -> ScalarFloat fixes (st--, Apr 13, 2023)
c692568  linops log_det type fixes (st--, Apr 13, 2023)
ff5211a  some more linops type fixes (st--, Apr 13, 2023)
f3cc17e  Merge remote-tracking branch 'upstream/v0.6' into st/beartype (st--, Apr 13, 2023)
af69a73  actually commit KeyArray and Scalar* types (st--, Apr 17, 2023)
cc97fe2  add beartype to pyproject (st--, Apr 17, 2023)
97d6fdd  from beartype.typing import ... (st--, Apr 17, 2023)
cd7c2ee  try to fix Self in gpjax/base/module (st--, Apr 17, 2023)
8a62aae  fix _check_shape (st--, Apr 17, 2023)
9359b05  gpjax.objectives: always import from gps and variational_families (st--, Apr 17, 2023)
c687cfe  Revert "gpjax.objectives: always import from gps and variational_fami… (st--, Apr 17, 2023)
96d3d1e  fix gpjax.objectives imported types (st--, Apr 17, 2023)
fc18652  <...> | None not supported by beartype; replaced by Optional[<...>] (st--, Apr 17, 2023)
4cbe562  gpjax.datasets: cannot specify strict array shape AND rely on _check_… (st--, Apr 17, 2023)
16006ac  our tfd.Distribution subclassing requires the fix introduced in jaxty… (st--, Apr 17, 2023)
c95badb  need to import base first! (st--, Apr 17, 2023)
3c8969d  bugfix (st--, Apr 17, 2023)
243e22f  AbstractKernel: string for forward references (st--, Apr 17, 2023)
0c3ae8a  remove from __future__ import annotations (st--, Apr 19, 2023)
cd6cf66  fix type annotations to make up for changes in 0c3ae8ac33e7938a2f4d5b… (st--, Apr 19, 2023)
c31885e  pytree map functions may take a non-Module argument (st--, Apr 21, 2023)
7083d08  ScalarFloat (st--, Apr 21, 2023)
692d337  VecNOrMatNM (st--, Apr 21, 2023)
1b4ae52  remove unnecessary / buggy methods (st--, Apr 21, 2023)
7286cd0  more ScalarFloat (st--, Apr 21, 2023)
6e55323  ScalarFloat (st--, Apr 21, 2023)
8159a2d  type fixes (st--, Apr 21, 2023)
62c7a69  fix shape type (st--, Apr 21, 2023)
3a8814a  fix one KeyArray (st--, Apr 24, 2023)
25ed38f  more ScalarFloat corrections in kernels (st--, Apr 24, 2023)
76dd035  fix test_stationary accordingly for ScalarFloat params (st--, Apr 24, 2023)
2ba27b2  fix return type (st--, Apr 24, 2023)
33a3e7b  ScalarInt for Polynomial kernel and fix test for Scalar* params (st--, Apr 24, 2023)
4dd9ed5  fix mock in test_abstract_variational_family (st--, Apr 24, 2023)
54ef19c  fix link_function and variational_expectations shape annotations (st--, Apr 24, 2023)
3a21676  minor test fix (st--, Apr 24, 2023)
551f055  fix exception test for beartype (st--, Apr 24, 2023)
6fc3a55  fix Constant mean function (st--, Apr 24, 2023)
bf2a483  base_kernel as kwarg in test_approximations (st--, Apr 24, 2023)
ac70773  rename func to test_ so it actually gets collected (st--, Apr 24, 2023)
e2b6a88  mark test_graph_kernel as broken (st--, Apr 24, 2023)
985d554  fix LinearOperator DTypeT (st--, Apr 24, 2023)
7ea37d2  Revert "fix Constant mean function" (st--, Apr 24, 2023)
760d470  fix test_mean_functions instead (st--, Apr 24, 2023)
fcf6572  fix one more bug in RFF test (st--, Apr 24, 2023)
78397d5  Self (st--, Apr 24, 2023)
67f4c94  relax fit objective type (st--, Apr 24, 2023)
68d6371  rename gpjax.utils -> gpjax.typing (st--, Apr 25, 2023)
65a9bfb  Kernel = Any -> string forward reference (st--, Apr 25, 2023)
c439a93  relax Gaussian.predict type annotation to include GaussianDistribution (st--, Apr 25, 2023)
edaf2cb  our own `Array` type that accepts both JAX and Numpy arrays (st--, Apr 25, 2023)
bb4d51f  some Float -> Num relaxations for graph kernel... (st--, Apr 25, 2023)
ea1537d  ScalarFloat for GraphKernel hyperparams (st--, Apr 25, 2023)
495c649  fix type hints to what happens (even if it seems wrong) (st--, Apr 25, 2023)
dca6799  type relaxation for deep_kernels.pct.py (st--, Apr 25, 2023)
f26d52e  some more minor consistency fixes (st--, Apr 25, 2023)
23efd03  bugfix (st--, Apr 25, 2023)
1077efa  Update examples/graph_kernels.pct.py (thomaspinder, Apr 25, 2023)
c292c7e  Update gpjax/dataset.py (thomaspinder, Apr 25, 2023)
7e16487  jaxtyping import hook for notebooks (st--, Apr 26, 2023)
6620802  conftest.py to apply jaxtyping import hook before loading tests (st--, Apr 26, 2023)
bc70fdf  remove import hook from gpjax/__init__ (st--, Apr 26, 2023)
249c9d0  Merge branch 'v0.6' of https://github.com/JaxGaussianProcesses/GPJax … (st--, Apr 26, 2023)
c4a8af8  Update gpjax/dataset.py (st--, Apr 26, 2023)
62abd1a  Update gpjax/dataset.py (thomaspinder, Apr 26, 2023)
42e21fb  fix tests of shape checks now that we have beartype (st--, Apr 26, 2023)
Files changed
28 changes: 15 additions & 13 deletions gpjax/__init__.py
@@ -13,19 +13,21 @@
# limitations under the License.
# ==============================================================================

from .dataset import Dataset
from .fit import fit
from .gps import Prior, construct_posterior
from .kernels import *
from .likelihoods import Bernoulli, Gaussian
from .mean_functions import Constant, Zero
from .objectives import (ELBO, CollapsedELBO, ConjugateMLL,
LogPosteriorDensity, NonConjugateMLL)
from .variational_families import (CollapsedVariationalGaussian,
ExpectationVariationalGaussian,
NaturalVariationalGaussian,
VariationalGaussian,
WhitenedVariationalGaussian)
from jaxtyping import install_import_hook
with install_import_hook("gpjax", "beartype.beartype"):
Contributor Author:

do we want to push beartype onto end users?

Collaborator:

I don't understand well enough to comment here. What are the pros/cons of this?

Contributor Author:

hm, mainly in some places we might have specified stricter types (array dtypes/shapes) than is strictly required by the code, so some code might have run fine if only beartype hadn't intervened. also, the explicit _check_shape error messages might be a bit more informative than the generic beartype ones. I don't know if there's also some more interaction with the jaxtyping @jaxtyped decorator.

Collaborator:

So if we didn't push beartype onto the end-user, then it would just be a testing utility for the package? I could easily imagine how any overly rigid beartype assertions could be annoying for an end user.

Contributor Author:

Yeah! Though on the other hand, it could help us discover more bugs not handled by the tests...

Contributor Author:

For example, a bunch of them were only uncovered by the notebooks. But I guess we could add beartype there as well, and then use that as a way to suggest to users that they should use e.g. beartype (and report when one of our type hints is wrong)?

Collaborator:

That's a good way to position it. Let's push it onto users then. If it becomes an issue, we can always walk it back with little major disruption.

Contributor Author:

Umm, do you mean force it on users (by including it in the general import), or strongly suggest it to users by having it at the start of every notebook (and inside the tests)? I think I'd be in favour of the latter actually...

Collaborator:

The latter - it's consistent with how we encourage people to use float64 without enforcing it.

Contributor Author:

done

from .dataset import Dataset
from .fit import fit
from .gps import Prior, construct_posterior
from .kernels import *
from .likelihoods import Bernoulli, Gaussian
from .mean_functions import Constant, Zero
from .objectives import (ELBO, CollapsedELBO, ConjugateMLL,
LogPosteriorDensity, NonConjugateMLL)
from .variational_families import (CollapsedVariationalGaussian,
ExpectationVariationalGaussian,
NaturalVariationalGaussian,
VariationalGaussian,
WhitenedVariationalGaussian)

__license__ = "MIT"
__description__ = "Didactic Gaussian processes in JAX"
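Following the resolution above (suggest rather than enforce), a user who wants runtime type and shape checking can install the hook themselves before importing the package, using the same hook call shown in this diff. A minimal opt-in sketch; the Dataset call at the end is illustrative and not taken from the PR:

from jaxtyping import install_import_hook

# Wrap every gpjax module with the jaxtyping/beartype decorators at import time.
with install_import_hook("gpjax", "beartype.beartype"):
    import gpjax as gpx

import jax.numpy as jnp

# With the hook active, annotated array shapes and dtypes are verified at runtime.
D = gpx.Dataset(X=jnp.ones((10, 2)), y=jnp.ones((10, 1)))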
43 changes: 22 additions & 21 deletions gpjax/base/module.py
@@ -13,14 +13,13 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations

__all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"]

import dataclasses
import os
from copy import copy, deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple
from beartype.typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union

import jax
import jax.tree_util as jtu
@@ -31,7 +30,9 @@
PyTreeCheckpointer, PyTreeCheckpointHandler,
RestoreArgs, SaveArgs)
from simple_pytree import Pytree, static_field
from typing_extensions import Self
thomaspinder marked this conversation as resolved.


Self = TypeVar('T')


class Module(Pytree):
@@ -49,7 +50,7 @@ def __init_subclass__(cls, mutable: bool = False):
):
cls._pytree__meta[field] = {**value.metadata}

def replace(self, **kwargs: Any) -> Self:
def replace(self: Self, **kwargs: Any) -> Self:
"""
Replace the values of the fields of the object.

@@ -68,7 +69,7 @@ def replace(self, **kwargs: Any) -> Self:
pytree.__dict__.update(kwargs)
return pytree

def replace_meta(self, **kwargs: Any) -> Self:
def replace_meta(self: Self, **kwargs: Any) -> Self:
"""
Replace the metadata of the fields.

@@ -87,7 +88,7 @@ def replace_meta(self, **kwargs: Any) -> Self:
pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs})
return pytree

def update_meta(self, **kwargs: Any) -> Self:
def update_meta(self: Self, **kwargs: Any) -> Self:
"""
Update the metadata of the fields. The metadata must already exist.

@@ -112,15 +113,15 @@ def update_meta(self, **kwargs: Any) -> Self:
pytree.__dict__.update(_pytree__meta=new)
return pytree

def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Self:
def replace_trainable(self: Self, **kwargs: Dict[str, bool]) -> Self:
"""Replace the trainability status of local nodes of the Module."""
return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})

def replace_bijector(self: Module, **kwargs: Dict[str, tfb.Bijector]) -> Self:
def replace_bijector(self: Self, **kwargs: Dict[str, tfb.Bijector]) -> Self:
"""Replace the bijectors of local nodes of the Module."""
return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})

def constrain(self) -> Self:
def constrain(self: Self) -> Self:
"""Transform model parameters to the constrained space according to their defined bijectors.

Returns:
@@ -137,7 +138,7 @@ def _apply_constrain(meta_leaf):

return meta_map(_apply_constrain, self)

def unconstrain(self) -> Self:
def unconstrain(self: Self) -> Self:
"""Transform model parameters to the unconstrained space according to their defined bijectors.

Returns:
@@ -154,7 +155,7 @@ def _apply_unconstrain(meta_leaf):

return meta_map(_apply_unconstrain, self)

def stop_gradient(self) -> Self:
def stop_gradient(self: Self) -> Self:
"""Stop gradients flowing through the Module.

Returns:
@@ -176,7 +177,7 @@ def _apply_stop_grad(meta_leaf):
return meta_map(_apply_stop_grad, self)


def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]:
def _toplevel_meta(pytree: Any) -> List[Optional[Dict[str, Any]]]:
"""Unpacks a list of meta corresponding to the top-level nodes of the pytree.

Args:
@@ -197,7 +198,7 @@ def meta_leaves(
def meta_leaves(
pytree: Module,
*,
is_leaf: Callable[[Any], bool] | None = None,
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> List[Tuple[Dict[str, Any], Any]]:
"""
Returns the meta of the leaves of the pytree.
@@ -212,8 +213,8 @@

def _unpack_metadata(
meta_leaf: Any,
pytree: Module,
is_leaf: Callable[[Any], bool] | None,
pytree: Union[Module, Any],
is_leaf: Optional[Callable[[Any], bool]],
):
"""Recursively unpack leaf metadata."""
if is_leaf and is_leaf(pytree):
@@ -235,8 +236,8 @@ def _unpack_metadata(


def meta_flatten(
pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None
) -> Module:
pytree: Union[Module, Any], *, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Union[Module, Any]:
"""
Returns the meta of the Module.

@@ -254,10 +255,10 @@ def meta_flatten(

def meta_map(
f: Callable[[Any, Dict[str, Any]], Any],
pytree: Module,
pytree: Union[Module, Any],
*rest: Any,
is_leaf: Callable[[Any], bool] | None = None,
) -> Module:
is_leaf: Optional[Callable[[Any], bool]] = None,
) -> Union[Module, Any]:
"""Apply a function to a Module where the first argument are the pytree leaves, and the second argument are the Module metadata leaves.
Args:
f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree.
@@ -273,7 +274,7 @@ def meta_map(
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))


def meta(pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None) -> Module:
def meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module:
"""Returns the metadata of the Module as a pytree.

Args:
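The recurring change in this file is replacing typing_extensions.Self (plus the `from __future__ import annotations` forward references and `X | None` unions, which the commit messages note beartype could not handle) with a module-level TypeVar and Optional[...]. A short sketch of the TypeVar-based pattern in isolation, using a hypothetical class rather than gpjax code:

from copy import copy
from beartype.typing import Any, TypeVar

SelfT = TypeVar("SelfT")  # stands in for typing_extensions.Self

class Node:
    def replace(self: SelfT, **kwargs: Any) -> SelfT:
        # Annotating `self` with the TypeVar ties the return type to the
        # concrete subclass, so Leaf().replace(...) is typed as Leaf, not Node.
        new = copy(self)
        new.__dict__.update(kwargs)
        return new

class Leaf(Node):
    value: int = 0

leaf: Leaf = Leaf().replace(value=1)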
3 changes: 1 addition & 2 deletions gpjax/base/param.py
@@ -13,12 +13,11 @@
# limitations under the License.
# ==============================================================================

from __future__ import annotations
thomaspinder marked this conversation as resolved.

__all__ = ["param_field"]

import dataclasses
from typing import Any, Mapping, Optional
from beartype.typing import Any, Mapping, Optional

import tensorflow_probability.substrates.jax.bijectors as tfb

11 changes: 5 additions & 6 deletions gpjax/dataset.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
from beartype.typing import Optional, Union

import jax.numpy as jnp
from jaxtyping import Array, Float
@@ -31,8 +30,8 @@ class Dataset(Pytree):
y (Optional[Float[Array, "N Q"]]): Output data.
"""

X: Optional[Float[Array, "N D"]] = None
y: Optional[Float[Array, "N Q"]] = None
X: Optional[Union[Float[Array, "N D"], Float[Array, "..."]]] = None
y: Optional[Union[Float[Array, "N Q"], Float[Array, "..."]]] = None
thomaspinder marked this conversation as resolved.

def __post_init__(self) -> None:
"""Checks that the shapes of X and y are compatible."""
@@ -54,7 +53,7 @@ def is_unsupervised(self) -> bool:
"""Returns `True` if the dataset is unsupervised."""
return self.X is None and self.y is not None

def __add__(self, other: Dataset) -> Dataset:
def __add__(self, other: "Dataset") -> "Dataset":
"""Combine two datasets. Right hand dataset is stacked beneath the left."""

X = None
@@ -84,7 +83,7 @@ def out_dim(self) -> int:
return self.y.shape[1]


def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None:
def _check_shape(X: Optional[Float[Array, "..."]], y: Optional[Float[Array, "..."]]) -> None:
Contributor Author:

Here I removed the beartype-shape checking. Could instead simply remove the _check_shape function...

Collaborator:

Is there any reason not to remove _check_shape now we use beartype? If not, then let's do it.

Contributor Author:

Actually, beartype only checks arguments individually, and does not check consistency of dimensions across multiple arguments/return values, so might be better to keep the _check_shape (so e.g. "N1 D" and "N2 D" shapes get flagged as not-matching).

Contributor Author:

Actually actually, jaxtyping itself checks those, I just had a bug in my toy example trying it out. But the error messages are still more verbose/less precise than the ones emitted by _check_shape 😞

Collaborator:

OK. I'm guessing there's no easy way for us to customise the error messages that Beartype throws?

Collaborator:

If the answer to my above question is no, then let's keep _check_shape.

Collaborator (@thomaspinder, Apr 25, 2023):

It looks like this is a very hot topic right now: patrick-kidger/jaxtyping#6.

"""Checks that the shapes of X and y are compatible."""
if X is not None and y is not None:
if X.shape[0] != y.shape[0]:
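On the thread above: beartype alone validates each argument against its own annotation, while the jaxtyping decorator additionally binds dimension names ("N", "D", "Q") within a call and checks that they agree across arguments, just with a less targeted error message than _check_shape. A rough illustration with a throwaway function (not part of the diff), using the @jaxtyped/@beartype stacking that this PR's import hook applies:

import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

@jaxtyped   # shares dimension bindings across arguments within one call
@beartype   # performs the per-argument runtime checks
def pair(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None:
    pass

pair(jnp.ones((10, 2)), jnp.ones((10, 1)))  # fine: "N" is 10 for both arguments
pair(jnp.ones((10, 2)), jnp.ones((5, 1)))   # raises: "N" bound to 10 by X, but y has 5 rows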
6 changes: 3 additions & 3 deletions gpjax/fit.py
@@ -13,17 +13,17 @@
# limitations under the License.
# ==============================================================================

from typing import Any, Optional, Tuple
from beartype.typing import Any, Optional, Tuple

import jax
import jax.random as jr
import optax as ox
from jax._src.random import _check_prng_key
from jax.random import KeyArray
from jaxtyping import Array, Float
from jaxlib.xla_extension import PjitFunction
from warnings import warn

from gpjax.utils import ScalarFloat, KeyArray
from .base import Module
from .dataset import Dataset
from .objectives import AbstractObjective
@@ -117,7 +117,7 @@ def fit(
_check_verbose(verbose)

# Unconstrained space loss function with stop-gradient rule for non-trainable params.
def loss(model: Module, batch: Dataset) -> Float[Array, "1"]:
def loss(model: Module, batch: Dataset) -> ScalarFloat:
model = model.stop_gradient()
return objective(model.constrain(), batch)

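The hunks above pull ScalarFloat and KeyArray from gpjax.utils (renamed to gpjax.typing later in the PR, commit 68d6371) instead of annotating the loss with Float[Array, "1"]. The alias definitions are not shown in this view; as a rough, hypothetical reconstruction from the commit messages, a scalar alias is a zero-dimensional jaxtyping array type along these lines:

# Hypothetical sketch of the Scalar* aliases referenced above; the actual
# definitions in gpjax.typing may differ (e.g. by also admitting Python scalars).
from jaxtyping import Array, Bool, Float, Int

ScalarFloat = Float[Array, ""]   # zero-dimensional float array, e.g. jnp.array(1.0)
ScalarInt = Int[Array, ""]
ScalarBool = Bool[Array, ""]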
32 changes: 17 additions & 15 deletions gpjax/gaussian_distribution.py
@@ -13,12 +13,14 @@
# limitations under the License.
# ==============================================================================

from typing import Any, Optional, Tuple

from beartype.typing import Any, Optional, Tuple

import jax.numpy as jnp
import jax.random as jr
from gpjax.utils import KeyArray
from gpjax.utils import ScalarFloat
from jax import vmap
from jax.random import KeyArray
from jaxtyping import Array, Float
import tensorflow_probability.substrates.jax as tfp

@@ -132,20 +134,20 @@ def event_shape(self) -> Tuple:
"""Returns the event shape."""
return self.loc.shape[-1:]

def entropy(self) -> Float[Array, "1"]:
def entropy(self) -> ScalarFloat:
"""Calculates the entropy of the distribution."""
return 0.5 * (
self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det()
)

def log_prob(self, y: Float[Array, "N"]) -> Float[Array, "1"]:
def log_prob(self, y: Float[Array, "N"]) -> ScalarFloat:
"""Calculates the log pdf of the multivariate Gaussian.

Args:
y (Float[Array, "N"]): The value to calculate the log probability of.

Returns:
Float[Array, "1"]: The log probability of the value.
ScalarFloat: The log probability of the value.
"""
mu = self.loc
sigma = self.scale
@@ -179,11 +181,11 @@ def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:

return vmap(affine_transformation)(Z)

def sample(self,seed: KeyArray, sample_shape: Tuple[int, int]): # pylint: disable=useless-super-delegation
"""See `Distribution.sample`."""
return self._sample_n(seed, sample_shape[0])
def sample(self, seed: KeyArray, sample_shape: Tuple[int, ...]): # pylint: disable=useless-super-delegation
"""See `Distribution.sample`."""
return self._sample_n(seed, sample_shape[0]) # TODO this looks weird, why ignore the second entry?
st-- marked this conversation as resolved.

def kl_divergence(self, other: "GaussianDistribution") -> Float[Array, "1"]:
def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat:
return _kl_divergence(self, other)


@@ -200,14 +202,14 @@ def _check_and_return_dimension(
return q.event_shape[-1]


def _frobeinius_norm_squared(matrix: Float[Array, "N N"]) -> Float[Array, "1"]:
def _frobenius_norm_squared(matrix: Float[Array, "N N"]) -> ScalarFloat:
"""Calculates the squared Frobenius norm of a matrix."""
return jnp.sum(jnp.square(matrix))


def _kl_divergence(
q: GaussianDistribution, p: GaussianDistribution
) -> Float[Array, "1"]:
) -> ScalarFloat:
"""Computes the KL divergence, KL[q||p], between two multivariate Gaussian distributions
q(x) = N(x; μq, Σq) and p(x) = N(x; μp, Σp).

@@ -216,7 +218,7 @@ def _kl_divergence(
p (GaussianDistribution): A multivariate Gaussian distribution.

Returns:
Float[Array, "1"]: The KL divergence between q and p.
ScalarFloat: The KL divergence between q and p.
"""

n_dim = _check_and_return_dimension(q, p)
@@ -237,14 +239,14 @@ def _kl_divergence(
diff = mu_p - mu_q

# trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
trace = _frobeinius_norm_squared(
trace = _frobenius_norm_squared(
sqrt_p.solve(sqrt_q.to_dense())
) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.

# Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
mahalanobis = _frobeinius_norm_squared(
mahalanobis = jnp.sum(jnp.square(
sqrt_p.solve(diff)
) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.
)) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.

# KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
return (mahalanobis - n_dim - sigma_q.log_det() + sigma_p.log_det() + trace) / 2.0
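The trace/Mahalanobis decomposition in _kl_divergence above works on the Cholesky-style square roots of both covariances. A standalone dense-array sketch of the same computation, which can be useful for sanity-checking the linear-operator version (assumed helper, not gpjax code):

import jax.numpy as jnp

def kl_mvn(mu_q, L_q, mu_p, L_p):
    # KL[N(μq, Lq Lqᵀ) || N(μp, Lp Lpᵀ)] from lower-triangular Cholesky factors.
    n = mu_q.shape[0]
    diff = mu_p - mu_q
    # tr[Σp⁻¹ Σq] = ||Lp⁻¹ Lq||_F²
    trace = jnp.sum(jnp.square(jnp.linalg.solve(L_p, L_q)))
    # (μp - μq)ᵀ Σp⁻¹ (μp - μq) = ||Lp⁻¹ (μp - μq)||²
    mahalanobis = jnp.sum(jnp.square(jnp.linalg.solve(L_p, diff)))
    # log-determinants from the Cholesky diagonals
    log_det_q = 2.0 * jnp.sum(jnp.log(jnp.diag(L_q)))
    log_det_p = 2.0 * jnp.sum(jnp.log(jnp.diag(L_p)))
    return 0.5 * (mahalanobis - n - log_det_q + log_det_p + trace)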