Introduce beartype & fix types #230

Merged: 91 commits, Apr 27, 2023

Commits
15b808f
add beartype dependency
st-- Apr 13, 2023
0eaec30
from typing import -> from beartype.typing import
st-- Apr 13, 2023
781bc38
jaxtyping import_hook for @jaxtyped @beartype everywhere
st-- Apr 13, 2023
67e315f
fix Type[] of class-as-argument
st-- Apr 13, 2023
e77732b
fix KeyArray type hint (should probably move into jaxutils though)
st-- Apr 13, 2023
5175fa7
fix return value of slice_input when active_dims is None
st-- Apr 13, 2023
c548d73
fix return value of squared_distance
st-- Apr 13, 2023
079936f
fix return type of recursive_bijectors
st-- Apr 13, 2023
1005212
fix slice_input type annotations
st-- Apr 13, 2023
0c57a4e
new KernelCallable type to fix kernel_fn annotations
st-- Apr 13, 2023
777b042
fix kernel __call__ annotation
st-- Apr 13, 2023
df88c69
fix KeyArray type hint
st-- Apr 13, 2023
711fcac
beartype does not like forward references; replaced with string types
st-- Apr 13, 2023
1c95fa5
linops other type hint fixes
st-- Apr 13, 2023
1b915cf
fix KeyArray
st-- Apr 13, 2023
c3e48c1
abstractions.py some type fixes
st-- Apr 13, 2023
4cc9db1
fix GaussianDistribution.log_prob return type
st-- Apr 13, 2023
4c4a572
fix depreciations & warnings
st-- Apr 13, 2023
c512195
Merge branch 'st/fix_depreciations' into st/beartype
st-- Apr 13, 2023
9516fdf
fix scalar array types
st-- Apr 13, 2023
2d45258
introduce ScalarBool, ScalarInt for jitted calls in abstractions
st-- Apr 13, 2023
ecfde64
relax LinearOperator's solve() types (can be both matrix or vector), …
st-- Apr 13, 2023
066a5d7
remove _stop_grad type hints, not sure what they should be
st-- Apr 13, 2023
ff68725
found some more
st-- Apr 13, 2023
4716773
Merge branch 'st/fix_depreciations' into st/beartype
st-- Apr 13, 2023
5c90755
float -> ScalarFloat fixes
st-- Apr 13, 2023
c692568
linops log_det type fixes
st-- Apr 13, 2023
ff5211a
some more linops type fixes
st-- Apr 13, 2023
f3cc17e
Merge remote-tracking branch 'upstream/v0.6' into st/beartype
st-- Apr 13, 2023
af69a73
actually commit KeyArray and Scalar* types
st-- Apr 17, 2023
cc97fe2
add beartype to pyproject
st-- Apr 17, 2023
97d6fdd
from beartype.typing import ...
st-- Apr 17, 2023
cd7c2ee
try to fix Self in gpjax/base/module
st-- Apr 17, 2023
8a62aae
fix _check_shape
st-- Apr 17, 2023
9359b05
gpjax.objectives: always import from gps and variational_families
st-- Apr 17, 2023
c687cfe
Revert "gpjax.objectives: always import from gps and variational_fami…
st-- Apr 17, 2023
96d3d1e
fix gpjax.objectives imported types
st-- Apr 17, 2023
fc18652
<...> | None not supported by beartype; replaced by Optional[<...>]
st-- Apr 17, 2023
4cbe562
gpjax.datasets: cannot specify strict array shape AND rely on _check_…
st-- Apr 17, 2023
16006ac
our tfd.Distribution subclassing requires the fix introduced in jaxty…
st-- Apr 17, 2023
c95badb
need to import base first!
st-- Apr 17, 2023
3c8969d
bugfix
st-- Apr 17, 2023
243e22f
AbstractKernel: string for forward references
st-- Apr 17, 2023
0c3ae8a
remove from __future__ import annotations
st-- Apr 19, 2023
cd6cf66
fix type annotations to make up for changes in 0c3ae8ac33e7938a2f4d5b…
st-- Apr 19, 2023
c31885e
pytree map functions may take a non-Module argument
st-- Apr 21, 2023
7083d08
ScalarFloat
st-- Apr 21, 2023
692d337
VecNOrMatNM
st-- Apr 21, 2023
1b4ae52
remove unnecessary / buggy methods
st-- Apr 21, 2023
7286cd0
more ScalarFloat
st-- Apr 21, 2023
6e55323
ScalarFloat
st-- Apr 21, 2023
8159a2d
type fixes
st-- Apr 21, 2023
62c7a69
fix shape type
st-- Apr 21, 2023
3a8814a
fix one KeyArray
st-- Apr 24, 2023
25ed38f
more ScalarFloat corrections in kernels
st-- Apr 24, 2023
76dd035
fix test_stationary accordingly for ScalarFloat params
st-- Apr 24, 2023
2ba27b2
fix return type
st-- Apr 24, 2023
33a3e7b
ScalarInt for Polynomial kernel and fix test for Scalar* params
st-- Apr 24, 2023
4dd9ed5
fix mock in test_abstract_variational_family
st-- Apr 24, 2023
54ef19c
fix link_function and variational_expectations shape annotations
st-- Apr 24, 2023
3a21676
minor test fix
st-- Apr 24, 2023
551f055
fix exception test for beartype
st-- Apr 24, 2023
6fc3a55
fix Constant mean function
st-- Apr 24, 2023
bf2a483
base_kernel as kwarg in test_approximations
st-- Apr 24, 2023
ac70773
rename func to test_ so it actually gets collected
st-- Apr 24, 2023
e2b6a88
mark test_graph_kernel as broken
st-- Apr 24, 2023
985d554
fix LinearOperator DTypeT
st-- Apr 24, 2023
7ea37d2
Revert "fix Constant mean function"
st-- Apr 24, 2023
760d470
fix test_mean_functions instead
st-- Apr 24, 2023
fcf6572
fix one more bug in RFF test
st-- Apr 24, 2023
78397d5
Self
st-- Apr 24, 2023
67f4c94
relax fit objective type
st-- Apr 24, 2023
68d6371
rename gpjax.utils -> gpjax.typing
st-- Apr 25, 2023
65a9bfb
Kernel = Any -> string forward reference
st-- Apr 25, 2023
c439a93
relax Gaussian.predict type annotation to include GaussianDistribution
st-- Apr 25, 2023
edaf2cb
our own `Array` type that accepts both JAX and Numpy arrays
st-- Apr 25, 2023
bb4d51f
some Float -> Num relaxations for graph kernel...
st-- Apr 25, 2023
ea1537d
ScalarFloat for GraphKernel hyperparams
st-- Apr 25, 2023
495c649
fix type hints to what happens (even if it seems wrong)
st-- Apr 25, 2023
dca6799
type relaxation for deep_kernels.pct.py
st-- Apr 25, 2023
f26d52e
some more minor consistency fixes
st-- Apr 25, 2023
23efd03
bugfix
st-- Apr 25, 2023
1077efa
Update examples/graph_kernels.pct.py
thomaspinder Apr 25, 2023
c292c7e
Update gpjax/dataset.py
thomaspinder Apr 25, 2023
7e16487
jaxtyping import hook for notebooks
st-- Apr 26, 2023
6620802
conftest.py to apply jaxtyping import hook before loading tests
st-- Apr 26, 2023
bc70fdf
remove import hook from gpjax/__init__
st-- Apr 26, 2023
249c9d0
Merge branch 'v0.6' of https://github.com/JaxGaussianProcesses/GPJax …
st-- Apr 26, 2023
c4a8af8
Update gpjax/dataset.py
st-- Apr 26, 2023
62abd1a
Update gpjax/dataset.py
thomaspinder Apr 26, 2023
42e21fb
fix tests of shape checks now that we have beartype
st-- Apr 26, 2023
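Two of the commits above encode a pattern worth spelling out. Commit 0c3ae8a removes `from __future__ import annotations`, so annotations are evaluated eagerly at import time, and commit fc18652 accordingly swaps PEP 604 unions (`<...> | None`) for `Optional[<...>]`, which evaluates on every supported Python version and is handled by the beartype release used here. A minimal sketch of the resulting style (the function and its signature are illustrative, not GPJax API):

```python
from beartype import beartype
from beartype.typing import List, Optional

@beartype
def select_dims(x: List[float], active_dims: Optional[List[int]] = None) -> List[float]:
    # Optional[...] evaluates on every supported Python version, whereas
    # `List[int] | None` requires Python >= 3.10 once annotations are no
    # longer deferred by `from __future__ import annotations`.
    return x if active_dims is None else [x[i] for i in active_dims]

select_dims([1.0, 2.0, 3.0], active_dims=[0, 2])  # -> [1.0, 3.0]
```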
Files changed
4 changes: 3 additions & 1 deletion examples/barycentres.pct.py
@@ -24,7 +24,9 @@
 import tensorflow_probability.substrates.jax.distributions as tfd
 from jax.config import config
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
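This hook-before-import pattern, repeated in every notebook below, is what turns the PR's jaxtyping annotations into runtime checks: each function in the hooked package is wrapped with `@jaxtyped` plus the named type checker, equivalent to decorating by hand. A self-contained sketch of the effect on one function (the signature is illustrative; newer jaxtyping releases spell the bare `@jaxtyped` as `@jaxtyped(typechecker=beartype)`):

```python
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

@jaxtyped  # enforce the shape annotations...
@beartype  # ...using beartype as the runtime checker
def squared_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, ""]:
    # Both inputs must be float vectors of the same length D; the result
    # is a scalar (zero-dimensional) array.
    return jnp.sum((x - y) ** 2)

squared_distance(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))  # OK
# squared_distance(jnp.array([1.0, 2.0]), jnp.array([3.0]))  # raises: D mismatch
```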
4 changes: 3 additions & 1 deletion examples/classification.pct.py
@@ -19,7 +19,9 @@
 from jax.config import config
 from jaxtyping import Array, Float
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
4 changes: 3 additions & 1 deletion examples/collapsed_vi.pct.py
@@ -17,7 +17,9 @@
 from jax import jit
 from jax.config import config
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
17 changes: 9 additions & 8 deletions examples/deep_kernels.pct.py
@@ -9,9 +9,8 @@
 # Gaussian process model's kernel through a neural network can offer a solution to this.
 
 # %%
-import typing as tp
 from dataclasses import dataclass, field
-from typing import Dict, Any
+from typing import Any
 
 import jax
 import jax.numpy as jnp
@@ -25,12 +24,14 @@
 from simple_pytree import static_field
 import flax
 
-import gpjax as gpx
-import gpjax.kernels as jk
-from gpjax.kernels import DenseKernelComputation
-from gpjax.kernels.base import AbstractKernel
-from gpjax.kernels.computations import AbstractKernelComputation
-from gpjax.base import param_field
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
+    import gpjax.kernels as jk
+    from gpjax.kernels import DenseKernelComputation
+    from gpjax.kernels.base import AbstractKernel
+    from gpjax.kernels.computations import AbstractKernelComputation
+    from gpjax.base import param_field
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
10 changes: 6 additions & 4 deletions examples/graph_kernels.pct.py
@@ -18,7 +18,9 @@
 from jax import jit
 from jax.config import config
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
@@ -85,9 +87,9 @@
 
 true_kernel = gpx.GraphKernel(
     laplacian=L,
-    lengthscale=jnp.array([2.3]),
-    variance=jnp.array([3.2]),
-    smoothness=jnp.array([6.1]),
+    lengthscale=2.3,
+    variance=3.2,
+    smoothness=6.1,
 )
 prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel)
 
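The scalar hyperparameters above lean on the `ScalarFloat`/`ScalarInt` aliases that several commits introduce alongside the `gpjax.utils` → `gpjax.typing` rename. A plausible sketch of those aliases, assuming definitions along these lines (the exact file contents may differ):

```python
from typing import Union

import numpy as np
from jaxtyping import Array as JAXArray, Bool, Float, Int

# Union array type accepting both JAX and NumPy arrays (commit edaf2cb).
Array = Union[JAXArray, np.ndarray]

# Scalars may arrive as Python numbers or as zero-dimensional arrays,
# which is what jitted code typically produces.
ScalarBool = Union[bool, Bool[Array, ""]]
ScalarInt = Union[int, Int[Array, ""]]
ScalarFloat = Union[float, Float[Array, ""]]
```

Under such aliases, `lengthscale=2.3` and `lengthscale=jnp.array(2.3)` both type-check, while the one-element vector `jnp.array([2.3])` does not, hence the change to plain floats above.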
6 changes: 4 additions & 2 deletions examples/kernels.pct.py
@@ -19,8 +19,10 @@
 from simple_pytree import static_field
 import numpy as np
 
-import gpjax as gpx
-from gpjax.base.param import param_field
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
+    from gpjax.base.param import param_field
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
4 changes: 3 additions & 1 deletion examples/regression.pct.py
@@ -13,7 +13,9 @@
 from jax import jit
 from jax.config import config
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
9 changes: 6 additions & 3 deletions examples/spatial.pct.py
@@ -22,7 +22,6 @@
 
 import fsspec
 import geopandas as gpd
-import gpjax as gpx
 import jax
 import jax.numpy as jnp
 import jax.random as jr
@@ -33,11 +32,15 @@
 import pystac_client
 import rioxarray as rio
 import xarray as xr
-from gpjax.base import param_field
-from gpjax.dataset import Dataset
 from jaxtyping import Array, Float
 from rioxarray.merge import merge_arrays
 
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
+    from gpjax.base import param_field
+    from gpjax.dataset import Dataset
+
 jax.config.update("jax_enable_x64", True)
 
 key = jr.PRNGKey(123)
6 changes: 4 additions & 2 deletions examples/uncollapsed_vi.pct.py
@@ -21,8 +21,10 @@
 from jax import jit
 from jax.config import config
 
-import gpjax as gpx
-import gpjax.kernels as jk
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
+    import gpjax.kernels as jk
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
4 changes: 3 additions & 1 deletion examples/yacht.pct.py
@@ -10,7 +10,9 @@
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler
 
-import gpjax as gpx
+from jaxtyping import install_import_hook
+with install_import_hook("gpjax", "beartype.beartype"):
+    import gpjax as gpx
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
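That concludes the notebook changes, all of which apply the hook at the top of the file; per commits 7e16487 and 6620802, the test suite gets the same treatment via a `conftest.py` so the hook is active before any test module imports gpjax. A hedged sketch of such a conftest (the repository's actual file may differ; it assumes `install_import_hook` may also be called outside a `with` block, leaving the hook installed):

```python
# conftest.py -- executed by pytest before test modules are collected,
# so gpjax is imported with runtime type checking already in place.
from jaxtyping import install_import_hook

install_import_hook("gpjax", "beartype.beartype")
```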
45 changes: 23 additions & 22 deletions gpjax/base/module.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-from __future__ import annotations
 
 __all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"]
 
 import dataclasses
 import os
 from copy import copy, deepcopy
-from typing import Any, Callable, Dict, Iterable, List, Tuple
+from beartype.typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
 
 import jax
 import jax.tree_util as jtu
@@ -31,7 +30,9 @@
                              PyTreeCheckpointer, PyTreeCheckpointHandler,
                              RestoreArgs, SaveArgs)
 from simple_pytree import Pytree, static_field
-from typing_extensions import Self
+
+
+Self = TypeVar('Self')
 
 
 class Module(Pytree):
@@ -49,7 +50,7 @@ def __init_subclass__(cls, mutable: bool = False):
         ):
             cls._pytree__meta[field] = {**value.metadata}
 
-    def replace(self, **kwargs: Any) -> Self:
+    def replace(self: Self, **kwargs: Any) -> Self:
         """
         Replace the values of the fields of the object.
 
@@ -68,7 +69,7 @@ def replace(self, **kwargs: Any) -> Self:
         pytree.__dict__.update(kwargs)
         return pytree
 
-    def replace_meta(self, **kwargs: Any) -> Self:
+    def replace_meta(self: Self, **kwargs: Any) -> Self:
         """
         Replace the metadata of the fields.
 
@@ -87,7 +88,7 @@ def replace_meta(self, **kwargs: Any) -> Self:
         pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs})
         return pytree
 
-    def update_meta(self, **kwargs: Any) -> Self:
+    def update_meta(self: Self, **kwargs: Any) -> Self:
         """
         Update the metadata of the fields. The metadata must already exist.
 
@@ -112,15 +113,15 @@ def update_meta(self, **kwargs: Any) -> Self:
         pytree.__dict__.update(_pytree__meta=new)
         return pytree
 
-    def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Self:
+    def replace_trainable(self: Self, **kwargs: Dict[str, bool]) -> Self:
         """Replace the trainability status of local nodes of the Module."""
         return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})
 
-    def replace_bijector(self: Module, **kwargs: Dict[str, tfb.Bijector]) -> Self:
+    def replace_bijector(self: Self, **kwargs: Dict[str, tfb.Bijector]) -> Self:
         """Replace the bijectors of local nodes of the Module."""
         return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})
 
-    def constrain(self) -> Self:
+    def constrain(self: Self) -> Self:
         """Transform model parameters to the constrained space according to their defined bijectors.
 
         Returns:
@@ -137,7 +138,7 @@ def _apply_constrain(meta_leaf):
 
         return meta_map(_apply_constrain, self)
 
-    def unconstrain(self) -> Self:
+    def unconstrain(self: Self) -> Self:
         """Transform model parameters to the unconstrained space according to their defined bijectors.
 
         Returns:
@@ -154,7 +155,7 @@ def _apply_unconstrain(meta_leaf):
 
         return meta_map(_apply_unconstrain, self)
 
-    def stop_gradient(self) -> Self:
+    def stop_gradient(self: Self) -> Self:
         """Stop gradients flowing through the Module.
 
         Returns:
@@ -176,7 +177,7 @@ def _apply_stop_grad(meta_leaf):
         return meta_map(_apply_stop_grad, self)
 
 
-def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]:
+def _toplevel_meta(pytree: Any) -> List[Optional[Dict[str, Any]]]:
     """Unpacks a list of meta corresponding to the top-level nodes of the pytree.
 
     Args:
@@ -197,8 +198,8 @@ def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]:
 def meta_leaves(
     pytree: Module,
     *,
-    is_leaf: Callable[[Any], bool] | None = None,
-) -> List[Tuple[Dict[str, Any], Any]]:
+    is_leaf: Optional[Callable[[Any], bool]] = None,
+) -> List[Tuple[Optional[Dict[str, Any]], Any]]:
     """
     Returns the meta of the leaves of the pytree.
 
@@ -212,8 +213,8 @@ def meta_leaves(
 
 def _unpack_metadata(
     meta_leaf: Any,
-    pytree: Module,
-    is_leaf: Callable[[Any], bool] | None,
+    pytree: Union[Module, Any],
+    is_leaf: Optional[Callable[[Any], bool]],
 ):
     """Recursively unpack leaf metadata."""
     if is_leaf and is_leaf(pytree):
@@ -235,8 +236,8 @@ def _unpack_metadata(
 
 
 def meta_flatten(
-    pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None
-) -> Module:
+    pytree: Union[Module, Any], *, is_leaf: Optional[Callable[[Any], bool]] = None
+) -> Union[Module, Any]:
     """
     Returns the meta of the Module.
 
@@ -254,10 +255,10 @@ def meta_flatten(
 
 def meta_map(
     f: Callable[[Any, Dict[str, Any]], Any],
-    pytree: Module,
+    pytree: Union[Module, Any],
     *rest: Any,
-    is_leaf: Callable[[Any], bool] | None = None,
-) -> Module:
+    is_leaf: Optional[Callable[[Any], bool]] = None,
+) -> Union[Module, Any]:
     """Apply a function to a Module where the first argument are the pytree leaves, and the second argument are the Module metadata leaves.
     Args:
         f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree.
@@ -273,7 +274,7 @@ def meta_map(
     return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
 
 
-def meta(pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None) -> Module:
+def meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module:
     """Returns the metadata of the Module as a pytree.
 
     Args:
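The `Self = TypeVar('Self')` substitution above is the PR's workaround for beartype not (then) resolving `typing_extensions.Self`: annotating `self` with an unbound TypeVar preserves the "returns its own subclass" contract. A toy illustration of the pattern (class names and fields are hypothetical):

```python
import dataclasses
from beartype.typing import Any, TypeVar

Self = TypeVar("Self")


@dataclasses.dataclass
class Base:
    value: float = 0.0

    def replace(self: Self, **kwargs: Any) -> Self:
        # Copy the instance, then overwrite the requested fields; callers
        # on a subclass get that subclass back, not the base class.
        new = dataclasses.replace(self)
        new.__dict__.update(kwargs)
        return new


@dataclasses.dataclass
class Leaf(Base):
    lengthscale: float = 1.0


leaf: Leaf = Leaf().replace(lengthscale=2.0)  # statically inferred as Leaf
```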
3 changes: 1 addition & 2 deletions gpjax/base/param.py
@@ -13,12 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-from __future__ import annotations
 
 __all__ = ["param_field"]
 
 import dataclasses
-from typing import Any, Mapping, Optional
+from beartype.typing import Any, Mapping, Optional
 
 import tensorflow_probability.substrates.jax.bijectors as tfb
 
15 changes: 7 additions & 8 deletions gpjax/dataset.py
@@ -12,15 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from __future__ import annotations
-
 from dataclasses import dataclass
-from typing import Optional
+from beartype.typing import Optional, Union
 
 import jax.numpy as jnp
-from jaxtyping import Array, Float
+from jaxtyping import Float, Num
 from simple_pytree import Pytree
 
+from gpjax.typing import Array
 
 @dataclass
 class Dataset(Pytree):
@@ -31,8 +30,8 @@ class Dataset(Pytree):
         y (Optional[Float[Array, "N Q"]]): Output data.
     """
 
-    X: Optional[Float[Array, "N D"]] = None
-    y: Optional[Float[Array, "N Q"]] = None
+    X: Optional[Num[Array, "N D"]] = None
+    y: Optional[Num[Array, "N Q"]] = None
 
     def __post_init__(self) -> None:
         """Checks that the shapes of X and y are compatible."""
@@ -54,7 +53,7 @@ def is_unsupervised(self) -> bool:
         """Returns `True` if the dataset is unsupervised."""
        return self.X is None and self.y is not None
 
-    def __add__(self, other: Dataset) -> Dataset:
+    def __add__(self, other: "Dataset") -> "Dataset":
        """Combine two datasets. Right hand dataset is stacked beneath the left."""
 
         X = None
@@ -84,7 +83,7 @@ def out_dim(self) -> int:
         return self.y.shape[1]
 
 
-def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None:
+def _check_shape(X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]) -> None:
     """Checks that the shapes of X and y are compatible."""
     if X is not None and y is not None:
         if X.shape[0] != y.shape[0]:
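The `Float` → `Num` relaxation in `Dataset` is about dtype generality: `Num` matches any numeric dtype, so integer-valued inputs (counts, graph indices) no longer fail validation, and the new `gpjax.typing.Array` union admits NumPy as well as JAX arrays. A small demonstration (the `Array` definition here is a stand-in for whatever `gpjax.typing` actually exports):

```python
from typing import Union

import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype
from jaxtyping import Num, jaxtyped

# Stand-in for the union type added in gpjax.typing (commit edaf2cb).
Array = Union[jax.Array, np.ndarray]

@jaxtyped
@beartype
def n_points(X: Num[Array, "N D"]) -> int:
    # Num matches any numeric dtype; the previous Float annotation would
    # have rejected integer arrays.
    return int(X.shape[0])

n_points(jnp.ones((5, 2)))             # JAX float array: accepted
n_points(np.arange(10).reshape(5, 2))  # NumPy int array: accepted
```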