diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index 45a78d0c..d996e9a2 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -24,7 +24,9 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jax.config import config -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 96f7f02e..7c136910 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -19,7 +19,9 @@ from jax.config import config from jaxtyping import Array, Float -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index 4d3677c4..493a1a7a 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -17,7 +17,9 @@ from jax import jit from jax.config import config -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/deep_kernels.pct.py b/examples/deep_kernels.pct.py index 86e09574..7b811b6a 100644 --- a/examples/deep_kernels.pct.py +++ b/examples/deep_kernels.pct.py @@ -9,9 +9,8 @@ # Gaussian process model's kernel through a neural network can offer a solution to this. # %% -import typing as tp from dataclasses import dataclass, field -from typing import Dict, Any +from typing import Any import jax import jax.numpy as jnp @@ -25,12 +24,14 @@ from simple_pytree import static_field import flax -import gpjax as gpx -import gpjax.kernels as jk -from gpjax.kernels import DenseKernelComputation -from gpjax.kernels.base import AbstractKernel -from gpjax.kernels.computations import AbstractKernelComputation -from gpjax.base import param_field +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + import gpjax.kernels as jk + from gpjax.kernels import DenseKernelComputation + from gpjax.kernels.base import AbstractKernel + from gpjax.kernels.computations import AbstractKernelComputation + from gpjax.base import param_field # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index daf8444f..a358b36c 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -18,7 +18,9 @@ from jax import jit from jax.config import config -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -85,9 +87,9 @@ true_kernel = gpx.GraphKernel( laplacian=L, - lengthscale=jnp.array([2.3]), - variance=jnp.array([3.2]), - smoothness=jnp.array([6.1]), + lengthscale=2.3, + variance=3.2, + smoothness=6.1, ) prior = gpx.Prior(mean_function=gpx.Zero(), kernel=true_kernel) diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index b0867210..4daa866d 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -19,8 +19,10 @@ from simple_pytree import static_field import numpy as np -import gpjax as gpx -from gpjax.base.param import param_field +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + from gpjax.base.param import param_field # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 03e713b5..e96bb82a 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -13,7 +13,9 @@ from jax import jit from jax.config import config -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/spatial.pct.py b/examples/spatial.pct.py index ee25862e..5e0b01ee 100644 --- a/examples/spatial.pct.py +++ b/examples/spatial.pct.py @@ -22,7 +22,6 @@ import fsspec import geopandas as gpd -import gpjax as gpx import jax import jax.numpy as jnp import jax.random as jr @@ -33,11 +32,15 @@ import pystac_client import rioxarray as rio import xarray as xr -from gpjax.base import param_field -from gpjax.dataset import Dataset from jaxtyping import Array, Float from rioxarray.merge import merge_arrays +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + from gpjax.base import param_field + from gpjax.dataset import Dataset + jax.config.update("jax_enable_x64", True) key = jr.PRNGKey(123) diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index e2cf2340..3961a933 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -21,8 +21,10 @@ from jax import jit from jax.config import config -import gpjax as gpx -import gpjax.kernels as jk +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx + import gpjax.kernels as jk # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index 6aa2f093..d323a2b3 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -10,7 +10,9 @@ from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler -import gpjax as gpx +from jaxtyping import install_import_hook +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax as gpx # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) diff --git a/gpjax/base/module.py b/gpjax/base/module.py index b661b2e0..90af2a86 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -13,14 +13,13 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations __all__ = ["Module", "meta_leaves", "meta_flatten", "meta_map", "meta"] import dataclasses import os from copy import copy, deepcopy -from typing import Any, Callable, Dict, Iterable, List, Tuple +from beartype.typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union import jax import jax.tree_util as jtu @@ -31,7 +30,9 @@ PyTreeCheckpointer, PyTreeCheckpointHandler, RestoreArgs, SaveArgs) from simple_pytree import Pytree, static_field -from typing_extensions import Self + + +Self = TypeVar('Self') class Module(Pytree): @@ -49,7 +50,7 @@ def __init_subclass__(cls, mutable: bool = False): ): cls._pytree__meta[field] = {**value.metadata} - def replace(self, **kwargs: Any) -> Self: + def replace(self: Self, **kwargs: Any) -> Self: """ Replace the values of the fields of the object. @@ -68,7 +69,7 @@ def replace(self, **kwargs: Any) -> Self: pytree.__dict__.update(kwargs) return pytree - def replace_meta(self, **kwargs: Any) -> Self: + def replace_meta(self: Self, **kwargs: Any) -> Self: """ Replace the metadata of the fields. @@ -87,7 +88,7 @@ def replace_meta(self, **kwargs: Any) -> Self: pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs}) return pytree - def update_meta(self, **kwargs: Any) -> Self: + def update_meta(self: Self, **kwargs: Any) -> Self: """ Update the metadata of the fields. The metadata must already exist. @@ -112,15 +113,15 @@ def update_meta(self, **kwargs: Any) -> Self: pytree.__dict__.update(_pytree__meta=new) return pytree - def replace_trainable(self: Module, **kwargs: Dict[str, bool]) -> Self: + def replace_trainable(self: Self, **kwargs: Dict[str, bool]) -> Self: """Replace the trainability status of local nodes of the Module.""" return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()}) - def replace_bijector(self: Module, **kwargs: Dict[str, tfb.Bijector]) -> Self: + def replace_bijector(self: Self, **kwargs: Dict[str, tfb.Bijector]) -> Self: """Replace the bijectors of local nodes of the Module.""" return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()}) - def constrain(self) -> Self: + def constrain(self: Self) -> Self: """Transform model parameters to the constrained space according to their defined bijectors. Returns: @@ -137,7 +138,7 @@ def _apply_constrain(meta_leaf): return meta_map(_apply_constrain, self) - def unconstrain(self) -> Self: + def unconstrain(self: Self) -> Self: """Transform model parameters to the unconstrained space according to their defined bijectors. Returns: @@ -154,7 +155,7 @@ def _apply_unconstrain(meta_leaf): return meta_map(_apply_unconstrain, self) - def stop_gradient(self) -> Self: + def stop_gradient(self: Self) -> Self: """Stop gradients flowing through the Module. Returns: @@ -176,7 +177,7 @@ def _apply_stop_grad(meta_leaf): return meta_map(_apply_stop_grad, self) -def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]: +def _toplevel_meta(pytree: Any) -> List[Optional[Dict[str, Any]]]: """Unpacks a list of meta corresponding to the top-level nodes of the pytree. Args: @@ -197,8 +198,8 @@ def _toplevel_meta(pytree: Any) -> List[Dict[str, Any]]: def meta_leaves( pytree: Module, *, - is_leaf: Callable[[Any], bool] | None = None, -) -> List[Tuple[Dict[str, Any], Any]]: + is_leaf: Optional[Callable[[Any], bool]] = None, +) -> List[Tuple[Optional[Dict[str, Any]], Any]]: """ Returns the meta of the leaves of the pytree. @@ -212,8 +213,8 @@ def meta_leaves( def _unpack_metadata( meta_leaf: Any, - pytree: Module, - is_leaf: Callable[[Any], bool] | None, + pytree: Union[Module, Any], + is_leaf: Optional[Callable[[Any], bool]], ): """Recursively unpack leaf metadata.""" if is_leaf and is_leaf(pytree): @@ -235,8 +236,8 @@ def _unpack_metadata( def meta_flatten( - pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None -) -> Module: + pytree: Union[Module, Any], *, is_leaf: Optional[Callable[[Any], bool]] = None +) -> Union[Module, Any]: """ Returns the meta of the Module. @@ -254,10 +255,10 @@ def meta_flatten( def meta_map( f: Callable[[Any, Dict[str, Any]], Any], - pytree: Module, + pytree: Union[Module, Any], *rest: Any, - is_leaf: Callable[[Any], bool] | None = None, -) -> Module: + is_leaf: Optional[Callable[[Any], bool]] = None, +) -> Union[Module, Any]: """Apply a function to a Module where the first argument are the pytree leaves, and the second argument are the Module metadata leaves. Args: f (Callable[[Any, Dict[str, Any]], Any]): The function to apply to the pytree. @@ -273,7 +274,7 @@ def meta_map( return treedef.unflatten(f(*xs) for xs in zip(*all_leaves)) -def meta(pytree: Module, *, is_leaf: Callable[[Any], bool] | None = None) -> Module: +def meta(pytree: Module, *, is_leaf: Optional[Callable[[Any], bool]] = None) -> Module: """Returns the metadata of the Module as a pytree. Args: diff --git a/gpjax/base/param.py b/gpjax/base/param.py index fecf3bf0..0354938b 100644 --- a/gpjax/base/param.py +++ b/gpjax/base/param.py @@ -13,12 +13,11 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations __all__ = ["param_field"] import dataclasses -from typing import Any, Mapping, Optional +from beartype.typing import Any, Mapping, Optional import tensorflow_probability.substrates.jax.bijectors as tfb diff --git a/gpjax/dataset.py b/gpjax/dataset.py index 20427f80..96dfc6c9 100644 --- a/gpjax/dataset.py +++ b/gpjax/dataset.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from __future__ import annotations from dataclasses import dataclass -from typing import Optional +from beartype.typing import Optional, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from jaxtyping import Float, Num from simple_pytree import Pytree - +from gpjax.typing import Array @dataclass class Dataset(Pytree): @@ -31,8 +30,8 @@ class Dataset(Pytree): y (Optional[Float[Array, "N Q"]]): Output data. """ - X: Optional[Float[Array, "N D"]] = None - y: Optional[Float[Array, "N Q"]] = None + X: Optional[Num[Array, "N D"]] = None + y: Optional[Num[Array, "N Q"]] = None def __post_init__(self) -> None: """Checks that the shapes of X and y are compatible.""" @@ -54,7 +53,7 @@ def is_unsupervised(self) -> bool: """Returns `True` if the dataset is unsupervised.""" return self.X is None and self.y is not None - def __add__(self, other: Dataset) -> Dataset: + def __add__(self, other: "Dataset") -> "Dataset": """Combine two datasets. Right hand dataset is stacked beneath the left.""" X = None @@ -84,7 +83,7 @@ def out_dim(self) -> int: return self.y.shape[1] -def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None: +def _check_shape(X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]) -> None: """Checks that the shapes of X and y are compatible.""" if X is not None and y is not None: if X.shape[0] != y.shape[0]: diff --git a/gpjax/fit.py b/gpjax/fit.py index 7cd2de75..367466ab 100644 --- a/gpjax/fit.py +++ b/gpjax/fit.py @@ -13,17 +13,18 @@ # limitations under the License. # ============================================================================== -from typing import Any, Optional, Tuple +from beartype.typing import Any, Optional, Tuple, Union, Callable import jax import jax.random as jr import optax as ox from jax._src.random import _check_prng_key -from jax.random import KeyArray -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from jaxlib.xla_extension import PjitFunction from warnings import warn +from gpjax.typing import ScalarFloat, KeyArray from .base import Module from .dataset import Dataset from .objectives import AbstractObjective @@ -33,7 +34,7 @@ def fit( *, model: Module, - objective: AbstractObjective, + objective: Union[AbstractObjective, Callable[[Module, Dataset], ScalarFloat]], train_data: Dataset, optim: ox.GradientTransformation, num_iters: Optional[int] = 100, @@ -69,7 +70,7 @@ def fit( >>> model = LinearModel(weight=1.0, bias=1.0) >>> >>> # (3) Define your loss function: - >>> class MeanSqaureError(gpx.AbstractObjective): + >>> class MeanSquareError(gpx.AbstractObjective): ... def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float: ... return jnp.mean((train_data.y - model(train_data.X)) ** 2) ... @@ -117,7 +118,7 @@ def fit( _check_verbose(verbose) # Unconstrained space loss function with stop-gradient rule for non-trainable params. - def loss(model: Module, batch: Dataset) -> Float[Array, "1"]: + def loss(model: Module, batch: Dataset) -> ScalarFloat: model = model.stop_gradient() return objective(model.constrain(), batch) diff --git a/gpjax/gaussian_distribution.py b/gpjax/gaussian_distribution.py index 367e3741..5a9e7556 100644 --- a/gpjax/gaussian_distribution.py +++ b/gpjax/gaussian_distribution.py @@ -13,13 +13,16 @@ # limitations under the License. # ============================================================================== -from typing import Any, Optional, Tuple + +from beartype.typing import Any, Optional, Tuple import jax.numpy as jnp import jax.random as jr +from gpjax.typing import KeyArray +from gpjax.typing import ScalarFloat from jax import vmap -from jax.random import KeyArray -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float import tensorflow_probability.substrates.jax as tfp from .linops import IdentityLinearOperator, LinearOperator @@ -132,20 +135,20 @@ def event_shape(self) -> Tuple: """Returns the event shape.""" return self.loc.shape[-1:] - def entropy(self) -> Float[Array, "1"]: + def entropy(self) -> ScalarFloat: """Calculates the entropy of the distribution.""" return 0.5 * ( self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det() ) - def log_prob(self, y: Float[Array, "N"]) -> Float[Array, "1"]: + def log_prob(self, y: Float[Array, "N"]) -> ScalarFloat: """Calculates the log pdf of the multivariate Gaussian. Args: y (Float[Array, "N"]): The value to calculate the log probability of. Returns: - Float[Array, "1"]: The log probability of the value. + ScalarFloat: The log probability of the value. """ mu = self.loc sigma = self.scale @@ -179,11 +182,11 @@ def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]: return vmap(affine_transformation)(Z) - def sample(self,seed: KeyArray, sample_shape: Tuple[int, int]): # pylint: disable=useless-super-delegation - """See `Distribution.sample`.""" - return self._sample_n(seed, sample_shape[0]) + def sample(self, seed: KeyArray, sample_shape: Tuple[int, ...]): # pylint: disable=useless-super-delegation + """See `Distribution.sample`.""" + return self._sample_n(seed, sample_shape[0]) # TODO this looks weird, why ignore the second entry? - def kl_divergence(self, other: "GaussianDistribution") -> Float[Array, "1"]: + def kl_divergence(self, other: "GaussianDistribution") -> ScalarFloat: return _kl_divergence(self, other) @@ -200,14 +203,14 @@ def _check_and_return_dimension( return q.event_shape[-1] -def _frobeinius_norm_squared(matrix: Float[Array, "N N"]) -> Float[Array, "1"]: +def _frobenius_norm_squared(matrix: Float[Array, "N N"]) -> ScalarFloat: """Calculates the squared Frobenius norm of a matrix.""" return jnp.sum(jnp.square(matrix)) def _kl_divergence( q: GaussianDistribution, p: GaussianDistribution -) -> Float[Array, "1"]: +) -> ScalarFloat: """Computes the KL divergence, KL[q||p], between two multivariate Gaussian distributions q(x) = N(x; μq, Σq) and p(x) = N(x; μp, Σp). @@ -216,7 +219,7 @@ def _kl_divergence( p (GaussianDistribution): A multivariate Gaussian distribution. Returns: - Float[Array, "1"]: The KL divergence between q and p. + ScalarFloat: The KL divergence between q and p. """ n_dim = _check_and_return_dimension(q, p) @@ -237,14 +240,14 @@ def _kl_divergence( diff = mu_p - mu_q # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])² - trace = _frobeinius_norm_squared( + trace = _frobenius_norm_squared( sqrt_p.solve(sqrt_q.to_dense()) ) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator. # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])² - mahalanobis = _frobeinius_norm_squared( + mahalanobis = jnp.sum(jnp.square( sqrt_p.solve(diff) - ) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s. + )) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s. # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2 return (mahalanobis - n_dim - sigma_q.log_det() + sigma_p.log_det() + trace) / 2.0 diff --git a/gpjax/gps.py b/gpjax/gps.py index 11bb7dea..55f25b8e 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -15,11 +15,13 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Optional, Dict +from beartype.typing import Any, Callable, Dict, Optional import jax.numpy as jnp -from jax.random import KeyArray, PRNGKey, normal -from jaxtyping import Array, Float +from jax.random import PRNGKey, normal +from jaxtyping import Float, Num +from gpjax.typing import KeyArray, Array + from simple_pytree import static_field from .base import Module, param_field @@ -96,7 +98,7 @@ class Prior(AbstractPrior): .. math:: - p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)). + p(f(\\cdot)) = \\mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)). To invoke a ``Prior`` distribution, only a kernel function is required. By default, the mean function will be set to zero. In general, this assumption @@ -156,7 +158,7 @@ def __rmul__(self, other: AbstractLikelihood): """ return self.__mul__(other) - def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: + def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution: """Compute the predictive prior distribution for a given set of parameters. The output of this function is a function that computes a TFP distribution for a given set of inputs. @@ -315,27 +317,27 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: @dataclass class ConjugatePosterior(AbstractPosterior): - """A Gaussian process posterior distribution when the constituent likelihood + r"""A Gaussian process posterior distribution when the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values :math:`f` can be analytically integrated out of the posterior distribution. As such, many computational operations can be simplified; something we make use of in this object. For a Gaussian process prior :math:`p(\mathbf{f})` and a Gaussian likelihood - :math:`p(y | \\mathbf{f}) = \\mathcal{N}(y\\mid \mathbf{f}, \\sigma^2))` where - :math:`\mathbf{f} = f(\\mathbf{x})`, the predictive posterior distribution at - a set of inputs :math:`\\mathbf{x}` is given by + :math:`p(y | \mathbf{f}) = \mathcal{N}(y\mid \mathbf{f}, \sigma^2))` where + :math:`\mathbf{f} = f(\mathbf{x})`, the predictive posterior distribution at + a set of inputs :math:`\mathbf{x}` is given by .. math:: - p(\\mathbf{f}^{\\star}\mid \mathbf{y}) & = \\int p(\\mathbf{f}^{\\star} \\mathbf{f} \\mid \\mathbf{y})\\\\ - & =\\mathcal{N}(\\mathbf{f}^{\\star} \\boldsymbol{\mu}_{\mid \mathbf{y}}, \\boldsymbol{\Sigma}_{\mid \mathbf{y}} + p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\ + & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} where .. math:: - \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x}')+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ - \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star\\prime}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). + \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ + \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). Example: >>> import gpjax as gpx @@ -349,10 +351,10 @@ class ConjugatePosterior(AbstractPosterior): def predict( self, - test_inputs: Float[Array, "N D"], + test_inputs: Num[Array, "N D"], train_data: Dataset, ) -> GaussianDistribution: - """Conditional on a training data set, compute the GP's posterior + r"""Conditional on a training data set, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. @@ -360,14 +362,14 @@ def predict( The predictive distribution of a conjugate GP is given by .. math:: - p(\\mathbf{f}^{\\star}\mid \mathbf{y}) & = \\int p(\\mathbf{f}^{\\star} \\mathbf{f} \\mid \\mathbf{y})\\\\ - & =\\mathcal{N}(\\mathbf{f}^{\\star} \\boldsymbol{\mu}_{\mid \mathbf{y}}, \\boldsymbol{\Sigma}_{\mid \mathbf{y}} + p(\mathbf{f}^{\star}\mid \mathbf{y}) & = \int p(\mathbf{f}^{\star} \mathbf{f} \mid \mathbf{y})\\ + & =\mathcal{N}(\mathbf{f}^{\star} \boldsymbol{\mu}_{\mid \mathbf{y}}, \boldsymbol{\Sigma}_{\mid \mathbf{y}} where .. math:: - \\boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left(k(\\mathbf{x}, \\mathbf{x}')+\\sigma^2\\mathbf{I}_n\\right)^{-1}\\mathbf{y} \\\\ - \\boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\\mathbf{x}^{\\star}, \\mathbf{x}^{\\star\\prime}) -k(\\mathbf{x}^{\\star}, \\mathbf{x})\\left( k(\\mathbf{x}, \\mathbf{x}') + \\sigma^2\\mathbf{I}_n \\right)^{-1}k(\\mathbf{x}, \\mathbf{x}^{\\star}). + \boldsymbol{\mu}_{\mid \mathbf{y}} & = k(\mathbf{x}^{\star}, \mathbf{x})\left(k(\mathbf{x}, \mathbf{x}')+\sigma^2\mathbf{I}_n\right)^{-1}\mathbf{y} \\ + \boldsymbol{\Sigma}_{\mid \mathbf{y}} & =k(\mathbf{x}^{\star}, \mathbf{x}^{\star\prime}) -k(\mathbf{x}^{\star}, \mathbf{x})\left( k(\mathbf{x}, \mathbf{x}') + \sigma^2\mathbf{I}_n \right)^{-1}k(\mathbf{x}, \mathbf{x}^{\star}). The conditioning set is a GPJax ``Dataset`` object, whilst predictions are made on a regular Jax array. @@ -546,7 +548,7 @@ def __post_init__(self): self.latent = normal(self.key, shape=(self.likelihood.num_datapoints, 1)) def predict( - self, test_inputs: Float[Array, "N D"], train_data: Dataset + self, test_inputs: Num[Array, "N D"], train_data: Dataset ) -> GaussianDistribution: """ Conditional on a set of training data, compute the GP's posterior diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py index 3cbfeba7..c19bd4ad 100644 --- a/gpjax/kernels/__init__.py +++ b/gpjax/kernels/__init__.py @@ -14,8 +14,9 @@ # ============================================================================== """JaxKern.""" -from .approximations import RFF from .base import AbstractKernel, ProductKernel, SumKernel + +from .approximations import RFF from .computations import (BasisFunctionComputation, ConstantDiagonalKernelComputation, DenseKernelComputation, DiagonalKernelComputation, diff --git a/gpjax/kernels/approximations/rff.py b/gpjax/kernels/approximations/rff.py index 93897487..a2c88ef4 100644 --- a/gpjax/kernels/approximations/rff.py +++ b/gpjax/kernels/approximations/rff.py @@ -1,13 +1,15 @@ from dataclasses import dataclass import tensorflow_probability.substrates.jax.bijectors as tfb -from jax.random import KeyArray, PRNGKey -from jaxtyping import Array, Float +from jax.random import PRNGKey +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import static_field from ...base import param_field from ..base import AbstractKernel from ..computations import BasisFunctionComputation +from gpjax.typing import KeyArray @dataclass @@ -31,7 +33,7 @@ class RFF(AbstractKernel): """ base_kernel: AbstractKernel = None num_basis_fns: int = static_field(50) - frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity) + frequencies: Float[Array, "M 1"] = param_field(None, bijector=tfb.Identity()) key: KeyArray = static_field(PRNGKey(123)) def __post_init__(self) -> None: @@ -78,4 +80,4 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]: Returns: Float[Array, "N L"]: A N x L array of features where L = 2M. """ - return self.compute_engine(self).compute_features(x) \ No newline at end of file + return self.compute_engine(self).compute_features(x) diff --git a/gpjax/kernels/base.py b/gpjax/kernels/base.py index 202eb5d2..a0e2834b 100644 --- a/gpjax/kernels/base.py +++ b/gpjax/kernels/base.py @@ -13,17 +13,23 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations import abc +from beartype.typing import Callable, Dict, List, Optional, Sequence, Type, Union from dataclasses import dataclass from functools import partial -from typing import Callable, List, Union import jax.numpy as jnp -from jaxtyping import Array, Float +import jax.random +import jax +from gpjax.typing import KeyArray +from gpjax.typing import ScalarFloat +from gpjax.typing import Array +from jaxtyping import Float, Num from simple_pytree import static_field +import tensorflow_probability.substrates.jax.distributions as tfd +from gpjax.typing import ScalarFloat from ..base import Module, param_field from .computations import AbstractKernelComputation, DenseKernelComputation @@ -32,36 +38,36 @@ class AbstractKernel(Module): """Base kernel class.""" - compute_engine: AbstractKernelComputation = static_field(DenseKernelComputation) - active_dims: List[int] = static_field(None) + compute_engine: Type[AbstractKernelComputation] = static_field(DenseKernelComputation) + active_dims: Optional[List[int]] = static_field(None) name: str = static_field("AbstractKernel") @property def ndims(self): return 1 if not self.active_dims else len(self.active_dims) - def cross_covariance(self, x: Float[Array, "N D"], y: Float[Array, "M D"]): + def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]): return self.compute_engine(self).cross_covariance(x, y) - def gram(self, x: Float[Array, "N D"]): + def gram(self, x: Num[Array, "N D"]): return self.compute_engine(self).gram(x) - def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N S"]: + def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]: """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. Args: - x (Float[Array, "N D"]): The matrix or vector that is to be sliced. + x (Float[Array, "... D"]): The matrix or vector that is to be sliced. Returns: - Float[Array, "N S"]: A sliced form of the input matrix. + Float[Array, "... Q"]: A sliced form of the input matrix. """ - return x[..., self.active_dims] + return x[..., self.active_dims] if self.active_dims is not None else x @abc.abstractmethod def __call__( self, x: Float[Array, "D"], y: Float[Array, "D"], - ) -> Float[Array, "1"]: + ) -> ScalarFloat: """Evaluate the kernel on a pair of inputs. Args: @@ -69,13 +75,13 @@ def __call__( y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The evaluated kernel function at the supplied inputs. + ScalarFloat: The evaluated kernel function at the supplied inputs. """ raise NotImplementedError def __add__( - self, other: Union[AbstractKernel, Float[Array, "1"]] - ) -> AbstractKernel: + self, other: Union["AbstractKernel", ScalarFloat] + ) -> "AbstractKernel": """Add two kernels together. Args: other (AbstractKernel): The kernel to be added to the current kernel. @@ -86,12 +92,12 @@ def __add__( if isinstance(other, AbstractKernel): return SumKernel(kernels=[self, other]) - - return SumKernel(kernels=[self, Constant(other)]) + else: + return SumKernel(kernels=[self, Constant(other)]) def __radd__( - self, other: Union[AbstractKernel, Float[Array, "1"]] - ) -> AbstractKernel: + self, other: Union["AbstractKernel", ScalarFloat] + ) -> "AbstractKernel": """Add two kernels together. Args: other (AbstractKernel): The kernel to be added to the current kernel. @@ -102,8 +108,8 @@ def __radd__( return self.__add__(other) def __mul__( - self, other: Union[AbstractKernel, Float[Array, "1"]] - ) -> AbstractKernel: + self, other: Union["AbstractKernel", ScalarFloat] + ) -> "AbstractKernel": """Multiply two kernels together. Args: @@ -114,24 +120,24 @@ def __mul__( """ if isinstance(other, AbstractKernel): return ProductKernel(kernels=[self, other]) - - return ProductKernel(kernels=[self, Constant(other)]) + else: + return ProductKernel(kernels=[self, Constant(other)]) @property - def spectral_density(self) -> tfd.Distribution: + def spectral_density(self) -> Optional[tfd.Distribution]: return None @dataclass class Constant(AbstractKernel): """ - A constant mean function. This function returns a repeated scalar value for all inputs. + A constant kernel. This kernel evaluates to a constant for all inputs. The scalar value itself can be treated as a model hyperparameter and learned during training. """ - constant: Float[Array, "1"] = param_field(jnp.array(0.0)) + constant: ScalarFloat = param_field(jnp.array(0.0)) - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs. Args: @@ -139,7 +145,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The evaluated kernel function at the supplied inputs. + ScalarFloat: The evaluated kernel function at the supplied inputs. """ return self.constant.squeeze() @@ -170,7 +176,7 @@ def __call__( self, x: Float[Array, "D"], y: Float[Array, "D"], - ) -> Float[Array, "1"]: + ) -> ScalarFloat: """Evaluate the kernel on a pair of inputs. Args: @@ -178,7 +184,7 @@ def __call__( y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The evaluated kernel function at the supplied inputs. + ScalarFloat: The evaluated kernel function at the supplied inputs. """ return self.operator(jnp.stack([k(x, y) for k in self.kernels])) diff --git a/gpjax/kernels/computations/base.py b/gpjax/kernels/computations/base.py index a2628450..227fcbfd 100644 --- a/gpjax/kernels/computations/base.py +++ b/gpjax/kernels/computations/base.py @@ -15,26 +15,25 @@ import abc from dataclasses import dataclass -from typing import Any +from beartype.typing import Any from jax import vmap -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float, Num from gpjax.linops import (DenseLinearOperator, DiagonalLinearOperator, LinearOperator) -Kernel = Any - @dataclass class AbstractKernelComputation: """Abstract class for kernel computations.""" - kernel: Kernel + kernel: "gpjax.kernels.base.AbstractKernel" def gram( self, - x: Float[Array, "N D"], + x: Num[Array, "N D"], ) -> LinearOperator: """Compute Gram covariance operator of the kernel function. @@ -49,7 +48,7 @@ def gram( @abc.abstractmethod def cross_covariance( - self, x: Float[Array, "N D"], y: Float[Array, "M D"] + self, x: Num[Array, "N D"], y: Num[Array, "M D"] ) -> Float[Array, "N M"]: """For a given kernel, compute the NxM gram matrix on an a pair of input matrices with shape NxD and MxD. @@ -63,7 +62,7 @@ def cross_covariance( """ raise NotImplementedError - def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: + def diagonal(self, inputs: Num[Array, "N D"]) -> DiagonalLinearOperator: """For a given kernel, compute the elementwise diagonal of the NxN gram matrix on an input matrix of shape NxD. diff --git a/gpjax/kernels/computations/basis_functions.py b/gpjax/kernels/computations/basis_functions.py index 39aa44fc..8ea4a327 100644 --- a/gpjax/kernels/computations/basis_functions.py +++ b/gpjax/kernels/computations/basis_functions.py @@ -1,7 +1,8 @@ from dataclasses import dataclass import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from gpjax.linops import DenseLinearOperator @@ -19,7 +20,6 @@ def cross_covariance( ) -> Float[Array, "N M"]: """For a pair of inputs, compute the cross covariance matrix between the inputs. Args: - params (Dict): A dictionary of parameters for which the cross-covariance matrix should be constructed with. x: A N x D array of inputs. y: A M x D array of inputs. @@ -35,7 +35,6 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator: """For the Gram matrix, we can save computations by computing only one matrix multiplication between the inputs and the scaled frequencies. Args: - params (Dict): A dictionary of parameters for which the Gram matrix should be constructed with. inputs: A N x D array of inputs. Returns: diff --git a/gpjax/kernels/computations/constant_diagonal.py b/gpjax/kernels/computations/constant_diagonal.py index 3e4baed2..f98c382f 100644 --- a/gpjax/kernels/computations/constant_diagonal.py +++ b/gpjax/kernels/computations/constant_diagonal.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== + import jax.numpy as jnp from jax import vmap -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from gpjax.linops import ConstantDiagonalLinearOperator, DiagonalLinearOperator @@ -41,7 +43,6 @@ def diagonal(self, inputs: Float[Array, "N D"]) -> DiagonalLinearOperator: Args: kernel (AbstractKernel): The kernel for which the variance vector should be computed for. - params (Dict): The kernel's parameter set. inputs (Float[Array, "N D"]): The input matrix. Returns: diff --git a/gpjax/kernels/computations/dense.py b/gpjax/kernels/computations/dense.py index bc036248..370280ca 100644 --- a/gpjax/kernels/computations/dense.py +++ b/gpjax/kernels/computations/dense.py @@ -14,7 +14,8 @@ # ============================================================================== from jax import vmap -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from .base import AbstractKernelComputation diff --git a/gpjax/kernels/computations/diagonal.py b/gpjax/kernels/computations/diagonal.py index 2f9eaf23..631d6f9a 100644 --- a/gpjax/kernels/computations/diagonal.py +++ b/gpjax/kernels/computations/diagonal.py @@ -14,7 +14,8 @@ # ============================================================================== from jax import vmap -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from gpjax.kernels.computations.base import AbstractKernelComputation from gpjax.linops import DiagonalLinearOperator diff --git a/gpjax/kernels/computations/eigen.py b/gpjax/kernels/computations/eigen.py index c90f53c6..9b32ea57 100644 --- a/gpjax/kernels/computations/eigen.py +++ b/gpjax/kernels/computations/eigen.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== + from dataclasses import dataclass -from typing import Dict import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float, Num from .base import AbstractKernelComputation @@ -25,7 +26,7 @@ @dataclass class EigenKernelComputation(AbstractKernelComputation): def cross_covariance( - self, x: Float[Array, "N D"], y: Float[Array, "M D"] + self, x: Num[Array, "N D"], y: Num[Array, "M D"] ) -> Float[Array, "N M"]: # Transform the eigenvalues of the graph Laplacian according to the # RBF kernel's SPDE form. diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py index 455eee9b..ce8cb0ca 100644 --- a/gpjax/kernels/non_euclidean/graph.py +++ b/gpjax/kernels/non_euclidean/graph.py @@ -13,12 +13,15 @@ # limitations under the License. # ============================================================================== + from dataclasses import dataclass import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp -from jaxtyping import Array, Float, Int +from gpjax.typing import Array +from jaxtyping import Float, Num, Int from simple_pytree import static_field +from gpjax.typing import ScalarFloat, ScalarInt from ...base import param_field from ..base import AbstractKernel @@ -40,17 +43,13 @@ class GraphKernel(AbstractKernel): compute_engine """ - laplacian: Float[Array, "N N"] = static_field(None) - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() - ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - smoothness: Float[Array, "1"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() - ) + laplacian: Num[Array, "N N"] = static_field(None) + lengthscale: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + smoothness: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) eigenvalues: Float[Array, "N"] = static_field(None) eigenvectors: Float[Array, "N N"] = static_field(None) - num_vertex: Int[Array, "1"] = static_field(None) + num_vertex: ScalarInt = static_field(None) compute_engine: AbstractKernelComputation = static_field(EigenKernelComputation) name: str = "Graph Matérn" @@ -63,22 +62,23 @@ def __post_init__(self): if self.num_vertex is None: self.num_vertex = self.eigenvalues.shape[0] - def __call__( + def __call__( # TODO not consistent with general kernel interface self, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], + x: Int[Array, "N 1"], + y: Int[Array, "N 1"], + *, + S, **kwargs, - ) -> Float[Array, "1"]: + ): """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. Args: - x (Float[Array, "1 D"]): Index of the ith vertex. - y (Float[Array, "1 D"]): Index of the jth vertex. + x (Float[Array, "N 1"]): Index of the ith vertex. + y (Float[Array, "N 1"]): Index of the jth vertex. Returns: - Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. + ScalarFloat: The value of :math:`k(v_i, v_j)`. """ - S = kwargs["S"] Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose( jax_gather_nd(self.eigenvectors, y) ) # shape (n,n) diff --git a/gpjax/kernels/non_euclidean/utils.py b/gpjax/kernels/non_euclidean/utils.py index c247b3d4..24efbc18 100644 --- a/gpjax/kernels/non_euclidean/utils.py +++ b/gpjax/kernels/non_euclidean/utils.py @@ -13,12 +13,13 @@ # limitations under the License. # ============================================================================== -from jaxtyping import Array, Int, Num +from gpjax.typing import Array +from jaxtyping import Int, Num def jax_gather_nd( - params: Num[Array, "N ..."], indices: Int[Array, "M"] -) -> Num[Array, "M ..."]: + params: Num[Array, "N *rest"], indices: Int[Array, "M 1"] +) -> Num[Array, "M *rest"]: """Slice a `params` array at a set of `indices`. Args: diff --git a/gpjax/kernels/nonstationary/linear.py b/gpjax/kernels/nonstationary/linear.py index cece303b..0dad23d7 100644 --- a/gpjax/kernels/nonstationary/linear.py +++ b/gpjax/kernels/nonstationary/linear.py @@ -17,7 +17,9 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -27,14 +29,14 @@ class Linear(AbstractKernel): """The linear kernel.""" - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Linear" def __call__( self, x: Float[Array, "D"], y: Float[Array, "D"], - ) -> Float[Array, "1"]: + ) -> ScalarFloat: """Evaluate the linear kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` .. math:: @@ -45,7 +47,7 @@ def __call__( y (Float[Array, "D"]): The right hand input of the kernel function. Returns: - Float[Array, "1"]: The evaluated kernel function :math:`k(x, y)` at the supplied inputs. + ScalarFloat: The evaluated kernel function :math:`k(x, y)` at the supplied inputs. """ x = self.slice_input(x) y = self.slice_input(y) diff --git a/gpjax/kernels/nonstationary/polynomial.py b/gpjax/kernels/nonstationary/polynomial.py index 1a47932e..0c250043 100644 --- a/gpjax/kernels/nonstationary/polynomial.py +++ b/gpjax/kernels/nonstationary/polynomial.py @@ -17,8 +17,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import static_field +from gpjax.typing import ScalarFloat, ScalarInt from ...base import param_field from ..base import AbstractKernel @@ -28,14 +30,14 @@ class Polynomial(AbstractKernel): """The Polynomial kernel with variable degree.""" - degree: int = static_field(2) - shift: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + degree: ScalarInt = static_field(2) + shift: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) def __post_init__(self): self.name = f"Polynomial (degree {self.degree})" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through @@ -49,9 +51,9 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " call Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + ScalarFloat: The value of :math:`k(x, y)`. """ - x = self.slice_input(x).squeeze() - y = self.slice_input(y).squeeze() - K = jnp.power(self.shift + jnp.dot(x * self.variance, y), self.degree) + x = self.slice_input(x) + y = self.slice_input(y) + K = jnp.power(self.shift + self.variance * jnp.dot(x, y), self.degree) return K.squeeze() diff --git a/gpjax/kernels/stationary/matern12.py b/gpjax/kernels/stationary/matern12.py index b820ec78..0e7a25ec 100644 --- a/gpjax/kernels/stationary/matern12.py +++ b/gpjax/kernels/stationary/matern12.py @@ -18,7 +18,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from beartype.typing import Union +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -29,13 +32,13 @@ class Matern12(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 0.5.""" - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( + jnp.array(1.0), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Matérn12" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -46,7 +49,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " x (Float[Array, "D"]): The left hand argument of the kernel function's call. y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + ScalarFloat: The value of :math:`k(x, y)` """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/matern32.py b/gpjax/kernels/stationary/matern32.py index 9d22fb1f..e5ba3561 100644 --- a/gpjax/kernels/stationary/matern32.py +++ b/gpjax/kernels/stationary/matern32.py @@ -18,7 +18,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from beartype.typing import Union +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -29,17 +32,17 @@ class Matern32(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 1.5.""" - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( + jnp.array(1.0), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Matérn32" def __call__( self, x: Float[Array, "D"], y: Float[Array, "D"], - ) -> Float[Array, "1"]: + ) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -51,7 +54,7 @@ def __call__( y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + ScalarFloat: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/matern52.py b/gpjax/kernels/stationary/matern52.py index 64b6be75..d810c082 100644 --- a/gpjax/kernels/stationary/matern52.py +++ b/gpjax/kernels/stationary/matern52.py @@ -18,7 +18,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from beartype.typing import Union +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -29,13 +32,13 @@ class Matern52(AbstractKernel): """The Matérn kernel with smoothness parameter fixed at 2.5.""" - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( + jnp.array(1.0), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Matérn52" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -47,7 +50,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + ScalarFloat: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/periodic.py b/gpjax/kernels/stationary/periodic.py index 4c744ff0..650f8247 100644 --- a/gpjax/kernels/stationary/periodic.py +++ b/gpjax/kernels/stationary/periodic.py @@ -14,11 +14,14 @@ # ============================================================================== from dataclasses import dataclass +from beartype.typing import Union import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -31,14 +34,14 @@ class Periodic(AbstractKernel): Key reference is MacKay 1998 - "Introduction to Gaussian processes". """ - lengthscale: Float[Array, "D"] = param_field( + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( jnp.array([1.0]), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - period: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + period: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Periodic" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` TODO: update docstring @@ -50,7 +53,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " x (Float[Array, "D"]): The left hand argument of the kernel function's call. y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + ScalarFloat: The value of :math:`k(x, y)` """ x = self.slice_input(x) y = self.slice_input(y) diff --git a/gpjax/kernels/stationary/powered_exponential.py b/gpjax/kernels/stationary/powered_exponential.py index dd1c1903..b6474dfe 100644 --- a/gpjax/kernels/stationary/powered_exponential.py +++ b/gpjax/kernels/stationary/powered_exponential.py @@ -14,11 +14,14 @@ # ============================================================================== from dataclasses import dataclass +from beartype.typing import Union import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -33,14 +36,14 @@ class PoweredExponential(AbstractKernel): """ - lengthscale: Float[Array, "D"] = param_field( + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( jnp.array([1.0]), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - power: Float[Array, "1"] = param_field(jnp.array([1.0])) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + power: ScalarFloat = param_field(jnp.array(1.0)) name: str = "Powered Exponential" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. .. math:: @@ -51,7 +54,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + ScalarFloat: The value of :math:`k(x, y)` """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/rational_quadratic.py b/gpjax/kernels/stationary/rational_quadratic.py index 2b08296f..44217e53 100644 --- a/gpjax/kernels/stationary/rational_quadratic.py +++ b/gpjax/kernels/stationary/rational_quadratic.py @@ -18,7 +18,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from beartype.typing import Union +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -27,14 +30,14 @@ @dataclass class RationalQuadratic(AbstractKernel): - lengthscale: Float[Array, "D"] = param_field( + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( jnp.array([1.0]), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) - alpha: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) + alpha: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "Rational Quadratic" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` .. math:: @@ -44,7 +47,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " x (Float[Array, "D"]): The left hand argument of the kernel function's call. y (Float[Array, "D"]): The right hand argument of the kernel function's call Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + ScalarFloat: The value of :math:`k(x, y)` """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/rbf.py b/gpjax/kernels/stationary/rbf.py index 964f8015..6ce74cde 100644 --- a/gpjax/kernels/stationary/rbf.py +++ b/gpjax/kernels/stationary/rbf.py @@ -15,10 +15,13 @@ from dataclasses import dataclass +from beartype.typing import Union import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -29,13 +32,13 @@ class RBF(AbstractKernel): """The Radial Basis Function (RBF) kernel.""" - lengthscale: Float[Array, "D"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() + lengthscale: Union[ScalarFloat, Float[Array, "D"]] = param_field( + jnp.array(1.0), bijector=tfb.Softplus() ) - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) name: str = "RBF" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` @@ -48,7 +51,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + ScalarFloat: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / self.lengthscale y = self.slice_input(y) / self.lengthscale diff --git a/gpjax/kernels/stationary/utils.py b/gpjax/kernels/stationary/utils.py index 9ff9248a..68b842b4 100644 --- a/gpjax/kernels/stationary/utils.py +++ b/gpjax/kernels/stationary/utils.py @@ -15,7 +15,9 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax as tfp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat tfd = tfp.distributions @@ -33,7 +35,7 @@ def build_student_t_distribution(nu: int) -> tfd.Distribution: return dist -def squared_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: +def squared_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Compute the squared distance between a pair of inputs. Args: @@ -41,13 +43,13 @@ def squared_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, y (Float[Array, "D"]): Second input. Returns: - Float[Array, "1"]: The squared distance between the inputs. + ScalarFloat: The squared distance between the inputs. """ return jnp.sum((x - y) ** 2) -def euclidean_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: +def euclidean_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Compute the euclidean distance between a pair of inputs. Args: @@ -55,7 +57,7 @@ def euclidean_distance(x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Arra y (Float[Array, "D"]): Second input. Returns: - Float[Array, "1"]: The euclidean distance between the inputs. + ScalarFloat: The euclidean distance between the inputs. """ return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) diff --git a/gpjax/kernels/stationary/white.py b/gpjax/kernels/stationary/white.py index 4ea508b5..4da424d8 100644 --- a/gpjax/kernels/stationary/white.py +++ b/gpjax/kernels/stationary/white.py @@ -18,8 +18,10 @@ import jax.numpy as jnp import tensorflow_probability.substrates.jax.bijectors as tfb import tensorflow_probability.substrates.jax.distributions as tfd -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import static_field +from gpjax.typing import ScalarFloat from ...base import param_field from ..base import AbstractKernel @@ -29,13 +31,13 @@ @dataclass class White(AbstractKernel): - variance: Float[Array, "1"] = param_field(jnp.array([1.0]), bijector=tfb.Softplus()) + variance: ScalarFloat = param_field(jnp.array(1.0), bijector=tfb.Softplus()) compute_engine: AbstractKernelComputation = static_field( ConstantDiagonalKernelComputation ) name: str = "White" - def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "1"]: + def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> ScalarFloat: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` .. math:: @@ -46,7 +48,7 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, " y (Float[Array, "D"]): The right hand argument of the kernel function's call. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. + ScalarFloat: The value of :math:`k(x, y)`. """ K = jnp.all(jnp.equal(x, y)) * self.variance return K.squeeze() diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index b4f9e2d0..42977948 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -15,16 +15,18 @@ import abc from dataclasses import dataclass -from typing import Any +from beartype.typing import Any, Union import jax.numpy as jnp import jax.scipy as jsp import tensorflow_probability.substrates.jax as tfp -from jaxtyping import Array, Float +from gpjax.typing import Array, ScalarFloat +from jaxtyping import Float from simple_pytree import static_field from .base import Module, param_field from .linops.utils import to_dense +from .gaussian_distribution import GaussianDistribution tfb = tfp.bijectors tfd = tfp.distributions @@ -63,7 +65,7 @@ def predict(self, *args: Any, **kwargs: Any) -> tfd.Distribution: @property @abc.abstractmethod - def link_function(self) -> tfd.Distribution: + def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution: """Return the link function of the likelihood function. Returns: @@ -76,16 +78,15 @@ def link_function(self) -> tfd.Distribution: class Gaussian(AbstractLikelihood): """Gaussian likelihood object.""" - obs_noise: Float[Array, "1"] = param_field( - jnp.array([1.0]), bijector=tfb.Softplus() + obs_noise: Union[ScalarFloat, Float[Array, "#N"]] = param_field( + jnp.array(1.0), bijector=tfb.Softplus() ) - def link_function(self, f: Float[Array, "N 1"]) -> tfd.Normal: + def link_function(self, f: Float[Array, "..."]) -> tfd.Normal: """The link function of the Gaussian likelihood. Args: - params (Dict): The parameters of the likelihood function. - f (Float[Array, "N 1"]): Function values. + f (Float[Array, "..."]): Function values. Returns: tfd.Normal: The likelihood function. @@ -93,7 +94,7 @@ def link_function(self, f: Float[Array, "N 1"]) -> tfd.Normal: return tfd.Normal(loc=f, scale=self.obs_noise.astype(f.dtype)) def predict( - self, dist: tfd.MultivariateNormalTriL + self, dist: Union[tfd.MultivariateNormalTriL, GaussianDistribution] ) -> tfd.MultivariateNormalFullCovariance: """ Evaluate the Gaussian likelihood function at a given predictive @@ -102,7 +103,6 @@ def predict( distribution's covariance matrix. Args: - params (Dict): The parameters of the likelihood function. dist (tfd.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. @@ -118,11 +118,11 @@ def predict( @dataclass class Bernoulli(AbstractLikelihood): - def link_function(self, f: Float[Array, "N 1"]) -> tfd.Distribution: + def link_function(self, f: Float[Array, "..."]) -> tfd.Distribution: """The probit link function of the Bernoulli likelihood. Args: - f (Float[Array, "N 1"]): Function values. + f (Float[Array, "..."]): Function values. Returns: tfd.Distribution: The likelihood function. @@ -134,7 +134,6 @@ def predict(self, dist: tfd.Distribution) -> tfd.Distribution: process posterior and likelihood parameters. Args: - params (Dict): The parameters of the likelihood function. dist (tfd.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. @@ -146,14 +145,14 @@ def predict(self, dist: tfd.Distribution) -> tfd.Distribution: return self.link_function(mean / jnp.sqrt(1.0 + variance)) -def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: +def inv_probit(x: Float[Array, "*N"]) -> Float[Array, "*N"]: """Compute the inverse probit function. Args: - x (Float[Array, "N 1"]): A vector of values. + x (Float[Array, "*N"]): A vector of values. Returns: - Float[Array, "N 1"]: The inverse probit of the input vector. + Float[Array, "*N"]: The inverse probit of the input vector. """ jitter = 1e-3 # To ensure output is in interval (0, 1). return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter diff --git a/gpjax/linops/constant_diagonal_linear_operator.py b/gpjax/linops/constant_diagonal_linear_operator.py index f525b885..458233ae 100644 --- a/gpjax/linops/constant_diagonal_linear_operator.py +++ b/gpjax/linops/constant_diagonal_linear_operator.py @@ -13,15 +13,17 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations +from beartype.typing import Any, Union from dataclasses import dataclass -from typing import Any, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import static_field +from gpjax.typing import ScalarFloat +from .linear_operator import LinearOperator from .diagonal_linear_operator import DiagonalLinearOperator from .linear_operator import LinearOperator @@ -63,7 +65,7 @@ def __init__( def __add__( self, other: Union[Float[Array, "N N"], LinearOperator] - ) -> DiagonalLinearOperator: + ) -> LinearOperator: if isinstance(other, ConstantDiagonalLinearOperator): if other.size == self.size: return ConstantDiagonalLinearOperator( @@ -77,7 +79,7 @@ def __add__( else: return super().__add__(other) - def __mul__(self, other: float) -> LinearOperator: + def __mul__(self, other: Union[ScalarFloat, Float[Array, "1"]]) -> LinearOperator: """Multiply covariance operator by scalar. Args: @@ -116,7 +118,7 @@ def diagonal(self) -> Float[Array, "N"]: """Diagonal of the covariance operator.""" return self.value * jnp.ones(self.size) - def to_root(self) -> ConstantDiagonalLinearOperator: + def to_root(self) -> "ConstantDiagonalLinearOperator": """ Lower triangular. @@ -127,15 +129,15 @@ def to_root(self) -> ConstantDiagonalLinearOperator: value=jnp.sqrt(self.value), size=self.size ) - def log_det(self) -> Float[Array, "1"]: + def log_det(self) -> ScalarFloat: """Log determinant. Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. + ScalarFloat: Log determinant of the covariance matrix. """ - return 2.0 * self.size * jnp.log(self.value) + return 2.0 * self.size * jnp.log(self.value.squeeze()) - def inverse(self) -> ConstantDiagonalLinearOperator: + def inverse(self) -> "ConstantDiagonalLinearOperator": """Inverse of the covariance operator. Returns: @@ -143,7 +145,7 @@ def inverse(self) -> ConstantDiagonalLinearOperator: """ return ConstantDiagonalLinearOperator(value=1.0 / self.value, size=self.size) - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: + def solve(self, rhs: Float[Array, "... M"]) -> Float[Array, "... M"]: """Solve linear system. Args: @@ -156,7 +158,7 @@ def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: return rhs / self.value @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "ConstantDiagonalLinearOperator": """Construct covariance operator from dense matrix. Args: @@ -171,8 +173,8 @@ def from_dense(cls, dense: Float[Array, "N N"]) -> ConstantDiagonalLinearOperato @classmethod def from_root( - cls, root: ConstantDiagonalLinearOperator - ) -> ConstantDiagonalLinearOperator: + cls, root: "ConstantDiagonalLinearOperator" + ) -> "ConstantDiagonalLinearOperator": """Construct covariance operator from root. Args: diff --git a/gpjax/linops/dense_linear_operator.py b/gpjax/linops/dense_linear_operator.py index ea648515..d985f3d6 100644 --- a/gpjax/linops/dense_linear_operator.py +++ b/gpjax/linops/dense_linear_operator.py @@ -13,21 +13,18 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator +from beartype.typing import Union from dataclasses import dataclass -from typing import Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from gpjax.linops.linear_operator import LinearOperator from gpjax.linops.utils import to_linear_operator +from gpjax.typing import ScalarFloat, VecNOrMatNM def _check_matrix(matrix: Array) -> None: @@ -95,7 +92,7 @@ def __add__( else: raise NotImplementedError - def __mul__(self, other: float) -> LinearOperator: + def __mul__(self, other: ScalarFloat) -> LinearOperator: """Multiply covariance operator by scalar. Args: @@ -107,7 +104,7 @@ def __mul__(self, other: float) -> LinearOperator: return DenseLinearOperator(matrix=self.matrix * other) - def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator: + def _add_diagonal(self, other: "gpjax.linops.diagonal_linear_operator.DiagonalLinearOperator") -> LinearOperator: """Add diagonal to the covariance operator, useful for computing, Kxx + Iσ². Args: @@ -131,7 +128,7 @@ def diagonal(self) -> Float[Array, "N"]: """ return jnp.diag(self.matrix) - def __matmul__(self, other: Float[Array, "N M"]) -> Float[Array, "N M"]: + def __matmul__(self, other: VecNOrMatNM) -> VecNOrMatNM: """Matrix multiplication. Args: @@ -152,7 +149,7 @@ def to_dense(self) -> Float[Array, "N N"]: return self.matrix @classmethod - def from_dense(cls, matrix: Float[Array, "N N"]) -> DenseLinearOperator: + def from_dense(cls, matrix: Float[Array, "N N"]) -> "DenseLinearOperator": """Construct covariance operator from dense covariance matrix. Args: @@ -164,7 +161,7 @@ def from_dense(cls, matrix: Float[Array, "N N"]) -> DenseLinearOperator: return DenseLinearOperator(matrix=matrix) @classmethod - def from_root(cls, root: LinearOperator) -> DenseLinearOperator: + def from_root(cls, root: LinearOperator) -> "DenseLinearOperator": """Construct covariance operator from the root of the covariance matrix. Args: diff --git a/gpjax/linops/diagonal_linear_operator.py b/gpjax/linops/diagonal_linear_operator.py index febf92d6..bc0a3136 100644 --- a/gpjax/linops/diagonal_linear_operator.py +++ b/gpjax/linops/diagonal_linear_operator.py @@ -13,17 +13,18 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations from dataclasses import dataclass -from typing import Any, Union +from beartype.typing import Any, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from .dense_linear_operator import DenseLinearOperator from .linear_operator import LinearOperator from .utils import to_linear_operator +from gpjax.typing import ScalarFloat, VecNOrMatNM def _check_diag(diag: Any) -> None: @@ -94,7 +95,7 @@ def __add__( else: raise NotImplementedError - def __mul__(self, other: float) -> LinearOperator: + def __mul__(self, other: ScalarFloat) -> LinearOperator: """Multiply covariance operator by scalar. Args: @@ -126,7 +127,7 @@ def to_dense(self) -> Float[Array, "N N"]: """ return jnp.diag(self.diagonal()) - def __matmul__(self, other: Float[Array, "N M"]) -> Float[Array, "N M"]: + def __matmul__(self, other: VecNOrMatNM) -> VecNOrMatNM: """Matrix multiplication. Args: @@ -141,7 +142,7 @@ def __matmul__(self, other: Float[Array, "N M"]) -> Float[Array, "N M"]: return diag * other - def to_root(self) -> DiagonalLinearOperator: + def to_root(self) -> "DiagonalLinearOperator": """ Lower triangular. @@ -150,15 +151,15 @@ def to_root(self) -> DiagonalLinearOperator: """ return DiagonalLinearOperator(diag=jnp.sqrt(self.diagonal())) - def log_det(self) -> Float[Array, "1"]: + def log_det(self) -> ScalarFloat: """Log determinant. Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. + ScalarFloat: Log determinant of the covariance matrix. """ return jnp.sum(jnp.log(self.diagonal())) - def inverse(self) -> DiagonalLinearOperator: + def inverse(self) -> "DiagonalLinearOperator": """Inverse of the covariance operator. Returns: @@ -166,7 +167,7 @@ def inverse(self) -> DiagonalLinearOperator: """ return DiagonalLinearOperator(diag=1.0 / self.diagonal()) - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: + def solve(self, rhs: VecNOrMatNM) -> VecNOrMatNM: """Solve linear system. Args: @@ -179,7 +180,7 @@ def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: return self.inverse() @ rhs @classmethod - def from_root(cls, root: DiagonalLinearOperator) -> DiagonalLinearOperator: + def from_root(cls, root: "DiagonalLinearOperator") -> "DiagonalLinearOperator": """Construct covariance operator from the lower triangular matrix. Returns: @@ -188,7 +189,7 @@ def from_root(cls, root: DiagonalLinearOperator) -> DiagonalLinearOperator: return DiagonalFromRootLinearOperator(root=root) @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> DiagonalLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "DiagonalLinearOperator": """Construct covariance operator from its dense matrix representation. Returns: diff --git a/gpjax/linops/identity_linear_operator.py b/gpjax/linops/identity_linear_operator.py index 0b308205..c2ea16e7 100644 --- a/gpjax/linops/identity_linear_operator.py +++ b/gpjax/linops/identity_linear_operator.py @@ -13,13 +13,14 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations +from beartype.typing import Any from dataclasses import dataclass -from typing import Any import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from gpjax.linops.constant_diagonal_linear_operator import \ ConstantDiagonalLinearOperator @@ -64,7 +65,7 @@ def __matmul__(self, other: Float[Array, "N M"]) -> Float[Array, "N M"]: """ return other - def to_root(self) -> IdentityLinearOperator: + def to_root(self) -> "IdentityLinearOperator": """ Lower triangular. @@ -73,15 +74,15 @@ def to_root(self) -> IdentityLinearOperator: """ return self - def log_det(self) -> Float[Array, "1"]: + def log_det(self) -> ScalarFloat: """Log determinant. Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. + ScalarFloat: Log determinant of the covariance matrix. """ return jnp.array(0.0) - def inverse(self) -> ConstantDiagonalLinearOperator: + def inverse(self) -> "IdentityLinearOperator": """Inverse of the covariance operator. Returns: @@ -89,7 +90,7 @@ def inverse(self) -> ConstantDiagonalLinearOperator: """ return self - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: + def solve(self, rhs: Float[Array, "... M"]) -> Float[Array, "... M"]: """Solve linear system. Args: @@ -103,7 +104,7 @@ def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: return rhs @classmethod - def from_root(cls, root: IdentityLinearOperator) -> IdentityLinearOperator: + def from_root(cls, root: "IdentityLinearOperator") -> "IdentityLinearOperator": """Construct from root. Args: @@ -115,7 +116,7 @@ def from_root(cls, root: IdentityLinearOperator) -> IdentityLinearOperator: return root @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> IdentityLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "IdentityLinearOperator": return IdentityLinearOperator(dense.shape[0]) diff --git a/gpjax/linops/linear_operator.py b/gpjax/linops/linear_operator.py index 66caf727..0abe1320 100644 --- a/gpjax/linops/linear_operator.py +++ b/gpjax/linops/linear_operator.py @@ -13,20 +13,16 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from gpjax.linops.diagonal_linear_operator import DiagonalLinearOperator import abc from dataclasses import dataclass -from typing import Any, Generic, Iterable, Mapping, Tuple, TypeVar, Union +from beartype.typing import Any, Generic, Iterable, Mapping, Tuple, Type, TypeVar, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import Pytree, static_field +from gpjax.typing import ScalarFloat # Generic type. T = TypeVar("T") @@ -35,8 +31,9 @@ NestedT = Union[T, Iterable["NestedT"], Mapping[Any, "NestedT"]] # Nested types. +DTypes = Union[Type[jnp.float32], Type[jnp.float64], Type[jnp.int32], Type[jnp.int64]] ShapeT = TypeVar("ShapeT", bound=NestedT[Tuple[int, ...]]) -DTypeT = TypeVar("DTypeT", bound=NestedT[jnp.dtype]) +DTypeT = TypeVar("DTypeT", bound=NestedT[DTypes]) # The Generic type is used for type checking the LinearOperator's shape and datatype. # `static_field` is used to mark nodes of the PyTree that don't change under JAX transformations. @@ -60,58 +57,58 @@ def ndim(self) -> int: return len(self.shape) @property - def T(self) -> LinearOperator: + def T(self) -> "LinearOperator": """Transpose linear operator. Currently, we assume all linear operators are square and symmetric.""" return self def __sub__( - self, other: Union[LinearOperator, Float[Array, "N N"]] - ) -> LinearOperator: + self, other: Union["LinearOperator", Float[Array, "N N"]] + ) -> "LinearOperator": """Subtract linear operator.""" return self + (other * -1) def __rsub__( - self, other: Union[LinearOperator, Float[Array, "N N"]] - ) -> LinearOperator: + self, other: Union["LinearOperator", Float[Array, "N N"]] + ) -> "LinearOperator": """Reimplimentation of subtract linear operator.""" return (self * -1) + other def __add__( - self, other: Union[LinearOperator, Float[Array, "N N"]] - ) -> LinearOperator: + self, other: Union["LinearOperator", Float[Array, "N N"]] + ) -> "LinearOperator": """Add linear operator.""" raise NotImplementedError def __radd__( - self, other: Union[LinearOperator, Float[Array, "N N"]] - ) -> LinearOperator: + self, other: Union["LinearOperator", Float[Array, "N N"]] + ) -> "LinearOperator": """Reimplimentation of add linear operator.""" return self + other @abc.abstractmethod - def __mul__(self, other: float) -> LinearOperator: + def __mul__(self, other: ScalarFloat) -> "LinearOperator": """Multiply linear operator by scalar.""" raise NotImplementedError - def __rmul__(self, other: float) -> LinearOperator: + def __rmul__(self, other: ScalarFloat) -> "LinearOperator": """Reimplimentation of multiply linear operator by scalar.""" return self.__mul__(other) @abc.abstractmethod - def _add_diagonal(self, other: DiagonalLinearOperator) -> LinearOperator: + def _add_diagonal(self, other: "gpjax.linops.diagonal_linear_operator.DiagonalLinearOperator") -> "LinearOperator": """Add diagonal linear operator to a linear operator, useful e.g., for adding jitter.""" return NotImplementedError @abc.abstractmethod def __matmul__( - self, other: Union[LinearOperator, Float[Array, "N M"]] - ) -> Union[LinearOperator, Float[Array, "N M"]]: + self, other: Union["LinearOperator", Float[Array, "N M"]] + ) -> Union["LinearOperator", Float[Array, "N M"]]: """Matrix multiplication.""" raise NotImplementedError def __rmatmul__( - self, other: Union[LinearOperator, Float[Array, "N M"]] - ) -> Union[LinearOperator, Float[Array, "N M"]]: + self, other: Union["LinearOperator", Float[Array, "N M"]] + ) -> Union["LinearOperator", Float[Array, "N M"]]: """Reimplimentation of matrix multiplication.""" # Exploit the fact that linear operators are square and symmetric. if other.ndim == 1: @@ -128,26 +125,26 @@ def diagonal(self) -> Float[Array, "N"]: raise NotImplementedError - def trace(self) -> Float[Array, "1"]: + def trace(self) -> ScalarFloat: """Trace of the linear matrix. Returns: - Float[Array, "1"]: Trace of the linear matrix. + ScalarFloat: Trace of the linear matrix. """ return jnp.sum(self.diagonal()) - def log_det(self) -> Float[Array, "1"]: + def log_det(self) -> ScalarFloat: """Log determinant of the linear matrix. Default implementation uses dense Cholesky decomposition. Returns: - Float[Array, "1"]: Log determinant of the linear matrix. + ScalarFloat: Log determinant of the linear matrix. """ root = self.to_root() return 2.0 * jnp.sum(jnp.log(root.diagonal())) - def to_root(self) -> LinearOperator: + def to_root(self) -> "LinearOperator": """Compute the root of the linear operator via the Cholesky decomposition. Returns: @@ -161,7 +158,7 @@ def to_root(self) -> LinearOperator: return LowerTriangularLinearOperator.from_dense(L) - def inverse(self) -> LinearOperator: + def inverse(self) -> "LinearOperator": """Inverse of the linear matrix. Default implementation uses dense Cholesky decomposition. Returns: @@ -174,7 +171,7 @@ def inverse(self) -> LinearOperator: return DenseLinearOperator(self.solve(jnp.eye(n))) - def solve(self, rhs: Float[Array, "N M"]) -> Float[Array, "N M"]: + def solve(self, rhs: Float[Array, "... M"]) -> Float[Array, "... M"]: """Solve linear system. Default implementation uses dense Cholesky decomposition. Args: @@ -199,7 +196,7 @@ def to_dense(self) -> Float[Array, "N N"]: raise NotImplementedError @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> LinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "LinearOperator": """Construct linear operator from dense matrix. Args: diff --git a/gpjax/linops/triangular_linear_operator.py b/gpjax/linops/triangular_linear_operator.py index 4fc9335a..000ee3c9 100644 --- a/gpjax/linops/triangular_linear_operator.py +++ b/gpjax/linops/triangular_linear_operator.py @@ -13,11 +13,11 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations import jax.numpy as jnp import jax.scipy as jsp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from .dense_linear_operator import DenseLinearOperator from .linear_operator import LinearOperator @@ -29,7 +29,7 @@ class LowerTriangularLinearOperator(DenseLinearOperator): """ @property - def T(self) -> UpperTriangularLinearOperator: + def T(self) -> "UpperTriangularLinearOperator": return UpperTriangularLinearOperator(matrix=self.matrix.T) def to_root(self) -> LinearOperator: @@ -39,21 +39,15 @@ def inverse(self) -> DenseLinearOperator: matrix = self.solve(jnp.eye(self.size)) return DenseLinearOperator(matrix) - def solve(self, rhs: Float[Array, "N"]) -> Float[Array, "N"]: + def solve(self, rhs: Float[Array, "... M"]) -> Float[Array, "... M"]: return jsp.linalg.solve_triangular(self.to_dense(), rhs, lower=True) - def __matmul__(self, other): - return super().__matmul__(other) - - def __add__(self, other): - return super().__matmul__(other) - @classmethod def from_root(cls, root: LinearOperator) -> None: raise ValueError("LowerTriangularLinearOperator does not have a root.") @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> LowerTriangularLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "LowerTriangularLinearOperator": return LowerTriangularLinearOperator(matrix=dense) @@ -73,13 +67,7 @@ def inverse(self) -> DenseLinearOperator: matrix = self.solve(jnp.eye(self.size)) return DenseLinearOperator(matrix) - def __matmul__(self, other): - return super().__matmul__(other) - - def __add__(self, other): - return super().__matmul__(other) - - def solve(self, rhs: Float[Array, "N"]) -> Float[Array, "N"]: + def solve(self, rhs: Float[Array, "... M"]) -> Float[Array, "... M"]: return jsp.linalg.solve_triangular(self.to_dense(), rhs, lower=False) @classmethod @@ -87,7 +75,7 @@ def from_root(cls, root: LinearOperator) -> None: raise ValueError("LowerTriangularLinearOperator does not have a root.") @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> UpperTriangularLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "UpperTriangularLinearOperator": return UpperTriangularLinearOperator(matrix=dense) diff --git a/gpjax/linops/utils.py b/gpjax/linops/utils.py index 0f313994..b3c0d816 100644 --- a/gpjax/linops/utils.py +++ b/gpjax/linops/utils.py @@ -13,21 +13,18 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations -from typing import TYPE_CHECKING, Tuple, Union - -if TYPE_CHECKING: - from gpjax.linops.identity_linear_operator import IdentityLinearOperator +from beartype.typing import Union, Tuple, Type import jax import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from gpjax.linops.linear_operator import LinearOperator -def identity(n: int) -> IdentityLinearOperator: +def identity(n: int) -> "gpjax.linops.identity_linear_operator.IdentityLinearOperator": """Identity matrix. Args: @@ -105,7 +102,7 @@ def check_shapes_match(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> None ) -def default_dtype() -> jnp.dtype: +def default_dtype() -> Union[Type[jnp.float64], Type[jnp.float32]]: """Get the default dtype for the linear operator. Returns: diff --git a/gpjax/linops/zero_linear_operator.py b/gpjax/linops/zero_linear_operator.py index f811e210..d998704c 100644 --- a/gpjax/linops/zero_linear_operator.py +++ b/gpjax/linops/zero_linear_operator.py @@ -13,14 +13,15 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations +from beartype.typing import Any, Tuple, Union from dataclasses import dataclass -from typing import Any, Tuple, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float +from gpjax.typing import ScalarFloat from .diagonal_linear_operator import DiagonalLinearOperator from .linear_operator import LinearOperator from .utils import check_shapes_match, default_dtype, to_linear_operator @@ -41,7 +42,7 @@ def _check_size(shape: Any) -> None: class ZeroLinearOperator(LinearOperator): """Zero linear operator.""" - def __init__(self, shape: Tuple[int], dtype: jnp.dtype = None) -> None: + def __init__(self, shape: Tuple[int, ...], dtype: jnp.dtype = None) -> None: _check_size(shape) if dtype is None: @@ -85,7 +86,7 @@ def _add_diagonal(self, other: DiagonalLinearOperator) -> DiagonalLinearOperator check_shapes_match(self.shape, other.shape) return other - def __mul__(self, other: float) -> ZeroLinearOperator: + def __mul__(self, other: ScalarFloat) -> "ZeroLinearOperator": """Multiply covariance operator by scalar. Args: @@ -99,7 +100,7 @@ def __mul__(self, other: float) -> ZeroLinearOperator: def __matmul__( self, other: Union[LinearOperator, Float[Array, "N M"]] - ) -> ZeroLinearOperator: + ) -> "ZeroLinearOperator": """Matrix multiplication. Args: @@ -119,7 +120,7 @@ def to_dense(self) -> Float[Array, "N N"]: """ return jnp.zeros(self.shape) - def to_root(self) -> ZeroLinearOperator: + def to_root(self) -> "ZeroLinearOperator": """ Root of the covariance operator. @@ -128,11 +129,11 @@ def to_root(self) -> ZeroLinearOperator: """ return self - def log_det(self) -> Float[Array, "1"]: + def log_det(self) -> ScalarFloat: """Log determinant. Returns: - Float[Array, "1"]: Log determinant of the covariance matrix. + ScalarFloat: Log determinant of the covariance matrix. """ return jnp.log(jnp.array(0.0)) @@ -144,7 +145,7 @@ def inverse(self) -> None: """ raise RuntimeError("ZeroLinearOperator is not invertible.") - def solve(self, rhs: Float[Array, "N M"]) -> None: + def solve(self, rhs: Float[Array, "... M"]) -> None: """Solve linear system. Raises: @@ -153,7 +154,7 @@ def solve(self, rhs: Float[Array, "N M"]) -> None: raise RuntimeError("ZeroLinearOperator is not invertible.") @classmethod - def from_root(cls, root: ZeroLinearOperator) -> ZeroLinearOperator: + def from_root(cls, root: "ZeroLinearOperator") -> "ZeroLinearOperator": """Construct covariance operator from the root. Args: @@ -165,7 +166,7 @@ def from_root(cls, root: ZeroLinearOperator) -> ZeroLinearOperator: return root @classmethod - def from_dense(cls, dense: Float[Array, "N N"]) -> ZeroLinearOperator: + def from_dense(cls, dense: Float[Array, "N N"]) -> "ZeroLinearOperator": """Construct covariance operator from the dense matrix. Args: diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 7bdafec0..e4f1a188 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -13,15 +13,15 @@ # limitations under the License. # ============================================================================== -from __future__ import annotations import abc import dataclasses from functools import partial -from typing import Callable, List, Union +from beartype.typing import Callable, List, Union import jax.numpy as jnp -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float, Num from simple_pytree import static_field from .base import Module, param_field @@ -32,7 +32,7 @@ class AbstractMeanFunction(Module): """Mean function that is used to parameterise the Gaussian process.""" @abc.abstractmethod - def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: + def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]: """Evaluate the mean function at the given points. This method is required for all subclasses. Args: @@ -44,8 +44,8 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: raise NotImplementedError def __add__( - self, other: Union[AbstractMeanFunction, Float[Array, "1"]] - ) -> AbstractMeanFunction: + self, other: Union["AbstractMeanFunction", Float[Array, "1"]] + ) -> "AbstractMeanFunction": """Add two mean functions. Args: @@ -61,8 +61,8 @@ def __add__( return SumMeanFunction([self, Constant(other)]) def __radd__( - self, other: Union[AbstractMeanFunction, Float[Array, "1"]] - ) -> AbstractMeanFunction: + self, other: Union["AbstractMeanFunction", Float[Array, "1"]] # TODO should this be ScalarFloat? or Num? + ) -> "AbstractMeanFunction": """Add two mean functions. Args: @@ -74,8 +74,8 @@ def __radd__( return self.__add__(other) def __mul__( - self, other: Union[AbstractMeanFunction, Float[Array, "1"]] - ) -> AbstractMeanFunction: + self, other: Union["AbstractMeanFunction", Float[Array, "1"]] # TODO should this be ScalarFloat? or Num? + ) -> "AbstractMeanFunction": """Multiply two mean functions. Args: @@ -90,8 +90,8 @@ def __mul__( return ProductMeanFunction([self, Constant(other)]) def __rmul__( - self, other: Union[AbstractMeanFunction, Float[Array, "1"]] - ) -> AbstractMeanFunction: + self, other: Union["AbstractMeanFunction", Float[Array, "1"]] # TODO should this be ScalarFloat? or Num? + ) -> "AbstractMeanFunction": """Multiply two mean functions. Args: @@ -107,17 +107,13 @@ def __rmul__( class Constant(AbstractMeanFunction): """ A constant mean function. This function returns a repeated scalar value for all inputs. -<<<<<<< HEAD - The scalar value itself can be treated as a model hyperparameter and learned during training. -======= The scalar value itself can be treated as a model hyperparameter and learned during training but defaults to 1.0. ->>>>>>> origin/rff_sampler """ constant: Float[Array, "1"] = param_field(jnp.array([0.0])) - def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: + def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]: """Evaluate the mean function at the given points. Args: @@ -161,7 +157,7 @@ def __init__( self.means = items_list self.operator = operator - def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: + def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]: """Evaluate combination kernel on a pair of inputs. Args: diff --git a/gpjax/objectives.py b/gpjax/objectives.py index 309000bd..0537d9c6 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -1,10 +1,3 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .gps import ConjugatePosterior, NonConjugatePosterior - from .variational_families import AbstractVariationalFamily from abc import abstractmethod from dataclasses import dataclass @@ -13,10 +6,12 @@ import jax.scipy as jsp import jax.tree_util as jtu from jax import vmap -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from simple_pytree import static_field import tensorflow_probability.substrates.jax as tfp +from gpjax.typing import ScalarFloat from .base import Module from .dataset import Dataset from .gaussian_distribution import GaussianDistribution @@ -37,19 +32,19 @@ def __post_init__(self) -> None: def __hash__(self): return hash(tuple(jtu.tree_leaves(self))) # Probably put this on the Module! - def __call__(self, *args, **kwargs) -> Float[Array, "1"]: + def __call__(self, *args, **kwargs) -> ScalarFloat: return self.step(*args, **kwargs) @abstractmethod - def step(self, *args, **kwargs) -> Float[Array, "1"]: + def step(self, *args, **kwargs) -> ScalarFloat: raise NotImplementedError class ConjugateMLL(AbstractObjective): def step( - self, posterior: ConjugatePosterior, train_data: Dataset - ) -> Float[Array, "1"]: + self, posterior: "gpjax.gps.ConjugatePosterior", train_data: Dataset + ) -> ScalarFloat: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation @@ -107,7 +102,7 @@ def step( Defaults to False. Returns: - Callable[[Parameters], Float[Array, "1"]]: A functional representation + Callable[[Parameters], ScalarFloat]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ @@ -131,8 +126,8 @@ def step( class NonConjugateMLL(AbstractObjective): def step( - self, posterior: NonConjugatePosterior, data: Dataset - ) -> Float[Array, "1"]: + self, posterior: "gpjax.gps.NonConjugatePosterior", data: Dataset + ) -> ScalarFloat: """ Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation @@ -158,7 +153,7 @@ def step( to maximisation of the marginal log-likelihood. Defaults to False. Returns: - Callable[[Parameters], Float[Array, "1"]]: A functional representation + Callable[[Parameters], ScalarFloat]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ @@ -190,8 +185,8 @@ def step( class ELBO(AbstractObjective): def step( - self, variational_family: AbstractVariationalFamily, train_data: Dataset - ) -> Float[Array, "1"]: + self, variational_family: "gpjax.variational_families.AbstractVariationalFamily", train_data: Dataset + ) -> ScalarFloat: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior @@ -233,9 +228,9 @@ def step( def variational_expectation( - variational_family: AbstractVariationalFamily, + variational_family: "gpjax.variational_families.AbstractVariationalFamily", train_data: Dataset, -) -> Float[Array, "N 1"]: +) -> Float[Array, "N"]: """Compute the expectation of our model's log-likelihood under our variational distribution. Batching can be done here to speed up computation. @@ -279,8 +274,8 @@ class CollapsedELBO(AbstractObjective): """ def step( - self, variational_family: AbstractVariationalFamily, train_data: Dataset - ) -> Float[Array, "1"]: + self, variational_family: "gpjax.variational_families.AbstractVariationalFamily", train_data: Dataset + ) -> ScalarFloat: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior diff --git a/gpjax/progress_bar.py b/gpjax/progress_bar.py index 5a10ef5e..19c84ff4 100644 --- a/gpjax/progress_bar.py +++ b/gpjax/progress_bar.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== -from typing import Any, Callable, Union +from beartype.typing import Any, Callable, Union from jax import lax from jax.experimental import host_callback -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float from tqdm.auto import tqdm diff --git a/gpjax/quadrature.py b/gpjax/quadrature.py index 908fdde4..acd81c6d 100644 --- a/gpjax/quadrature.py +++ b/gpjax/quadrature.py @@ -13,11 +13,12 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Optional +from beartype.typing import Callable, Optional import jax.numpy as jnp import numpy as np -from jaxtyping import Array, Float +from gpjax.typing import Array +from jaxtyping import Float """The number of Gauss-Hermite points to use for quadrature""" DEFAULT_NUM_GAUSS_HERMITE_POINTS = 20 diff --git a/gpjax/scan.py b/gpjax/scan.py index 9eeb7c10..ad534013 100644 --- a/gpjax/scan.py +++ b/gpjax/scan.py @@ -13,21 +13,23 @@ # limitations under the License. # ============================================================================== -from typing import Any, Callable, List, Optional, Tuple, TypeVar +from beartype.typing import Any, Callable, List, Optional, Tuple, TypeVar import jax import jax.numpy as jnp import jax.tree_util as jtu from jax import lax +from jaxtyping import Shaped, Array from jax.experimental import host_callback as hcb from tqdm.auto import trange +from gpjax.typing import ScalarInt, ScalarBool Carry = TypeVar("Carry") X = TypeVar("X") Y = TypeVar("Y") -def _callback(cond: bool, func: Callable, *args: Any) -> None: +def _callback(cond: ScalarBool, func: Callable, *args: Any) -> None: """Callback a function for a given argument if a condition is true. Args: @@ -59,7 +61,7 @@ def vscan( unroll: Optional[int] = 1, log_rate: Optional[int] = 10, log_value: Optional[bool] = True, -) -> Tuple[Carry, List[Y]]: +) -> Tuple[Carry, Shaped[Array, "..."]]: # return type should be Tuple[Carry, Y[Array]]... """Scan with verbose output. This is based on code from the excellent blog post: @@ -112,7 +114,7 @@ def _close_tqdm(args: Any, transform: Any) -> None: """Close the tqdm progress bar.""" _progress_bar.close() - def _body_fun(carry: Carry, iter_num_and_x: Tuple[int, X]) -> Tuple[Carry, Y]: + def _body_fun(carry: Carry, iter_num_and_x: Tuple[ScalarInt, X]) -> Tuple[Carry, Y]: # Unpack iter_num and x. iter_num, x = iter_num_and_x diff --git a/gpjax/typing.py b/gpjax/typing.py new file mode 100644 index 00000000..08d56808 --- /dev/null +++ b/gpjax/typing.py @@ -0,0 +1,38 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from beartype.typing import Union +from jaxtyping import Bool, UInt32, Int, Float +from numpy import ndarray as NumpyArray +from jaxtyping import Array as JAXArray +from jax.random import KeyArray as JAXKeyArray + +OldKeyArray = UInt32[JAXArray, "2"] +KeyArray = Union[OldKeyArray, JAXKeyArray] # for compatibility regardless of enable_custom_prng setting + +Array = Union[JAXArray, NumpyArray] + +ScalarBool = Union[bool, Bool[Array, ""]] +ScalarInt = Union[int, Int[Array, ""]] +ScalarFloat = Union[float, Float[Array, ""]] + +VecNOrMatNM = Union[Float[Array, "N"], Float[Array, "N M"]] + +__all__ = [ + "KeyArray", + "ScalarBool", + "ScalarInt", + "ScalarFloat", +] diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 51e18f6e..bfdc1750 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -14,13 +14,14 @@ # ============================================================================== import abc +from beartype.typing import Any from dataclasses import dataclass -from typing import Any import jax.numpy as jnp import jax.scipy as jsp import tensorflow_probability.substrates.jax.bijectors as tfb -from jaxtyping import Array, Float +from gpjax.typing import Array, ScalarFloat +from jaxtyping import Float from simple_pytree import static_field from .base import Module, param_field @@ -30,6 +31,7 @@ from .likelihoods import Gaussian from .linops import (DenseLinearOperator, LowerTriangularLinearOperator, identity) +from gpjax.typing import ScalarFloat @dataclass @@ -76,7 +78,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" inducing_inputs: Float[Array, "N D"] - jitter: Float[Array, "1"] = static_field(1e-6) + jitter: ScalarFloat = static_field(1e-6) @property def num_inducing(self) -> int: @@ -107,7 +109,7 @@ def __post_init__(self) -> None: if self.variational_root_covariance is None: self.variational_root_covariance = jnp.eye(self.num_inducing) - def prior_kl(self) -> Float[Array, "1"]: + def prior_kl(self) -> ScalarFloat: """ Compute the KL-divergence between our variational approximation and the Gaussian process prior. @@ -117,7 +119,7 @@ def prior_kl(self) -> Float[Array, "1"]: inputs. Returns: - Float[Array, "1"]: The KL-divergence between our variational + ScalarFloat: The KL-divergence between our variational approximation and the GP prior. """ @@ -217,14 +219,14 @@ class WhitenedVariationalGaussian(VariationalGaussian): over μ and sqrt with S = sqrt sqrtᵀ. """ - def prior_kl(self) -> Float[Array, "1"]: + def prior_kl(self) -> ScalarFloat: """Compute the KL-divergence between our variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(0, I)]. Returns: - Float[Array, "1"]: The KL-divergence between our variational + ScalarFloat: The KL-divergence between our variational approximation and the GP prior. """ @@ -317,7 +319,7 @@ def __post_init__(self): if self.natural_matrix is None: self.natural_matrix = -0.5 * jnp.eye(self.num_inducing) - def prior_kl(self) -> Float[Array, "1"]: + def prior_kl(self) -> ScalarFloat: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], @@ -325,7 +327,7 @@ def prior_kl(self) -> Float[Array, "1"]: with μ and S computed from the natural parameterisation θ = (S⁻¹μ, -S⁻¹/2). Returns: - Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. + ScalarFloat: The KL-divergence between our variational approximation and the GP prior. """ # Unpack variational parameters @@ -464,18 +466,15 @@ def __post_init__(self): if self.expectation_matrix is None: self.expectation_matrix = jnp.eye(self.num_inducing) - def prior_kl(self) -> Float[Array, "1"]: + def prior_kl(self) -> ScalarFloat: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], with μ and S computed from the expectation parameterisation η = (μ, S + uuᵀ). - Args: - params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. - Returns: - Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. + ScalarFloat: The KL-divergence between our variational approximation and the GP prior. """ # Unpack variational parameters diff --git a/poetry.lock b/poetry.lock index 182abdff..7a56cc56 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.3.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -221,14 +221,14 @@ test = ["coverage", "flake8", "mypy", "pexpect", "wheel"] [[package]] name = "astroid" -version = "2.15.3" +version = "2.15.4" description = "An abstract syntax tree for Python with inference support." category = "dev" optional = false python-versions = ">=3.7.2" files = [ - {file = "astroid-2.15.3-py3-none-any.whl", hash = "sha256:f11e74658da0f2a14a8d19776a8647900870a63de71db83713a8e77a6af52662"}, - {file = "astroid-2.15.3.tar.gz", hash = "sha256:44224ad27c54d770233751315fa7f74c46fa3ee0fab7beef1065f99f09897efe"}, + {file = "astroid-2.15.4-py3-none-any.whl", hash = "sha256:a1b8543ef9d36ea777194bc9b17f5f8678d2c56ee6a45b2c2f17eec96f242347"}, + {file = "astroid-2.15.4.tar.gz", hash = "sha256:c81e1c7fbac615037744d067a9bb5f9aeb655edf59b63ee8b59585475d6f80d8"}, ] [package.dependencies] @@ -315,6 +315,25 @@ files = [ {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, ] +[[package]] +name = "beartype" +version = "0.13.1" +description = "Unbearably fast runtime type checking in pure Python." +category = "main" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "beartype-0.13.1-py3-none-any.whl", hash = "sha256:c3097b487e57bc278f1b55da8863b704b2a786c46483a6d3df39ab6fe2523d80"}, + {file = "beartype-0.13.1.tar.gz", hash = "sha256:2903947a8a1eb6030264e30108aa72cb1a805cfc9050c0f4014c4aed3a17a00b"}, +] + +[package.extras] +all = ["typing-extensions (>=3.10.0.0)"] +dev = ["autoapi (>=0.9.0)", "coverage (>=5.5)", "mypy (>=0.800)", "numpy", "pandera", "pydata-sphinx-theme (<=0.7.2)", "pytest (>=4.0.0)", "sphinx", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)", "tox (>=3.20.1)", "typing-extensions"] +doc-rtd = ["autoapi (>=0.9.0)", "pydata-sphinx-theme (<=0.7.2)", "sphinx (>=4.2.0,<6.0.0)", "sphinxext-opengraph (>=0.7.5)"] +test-tox = ["mypy (>=0.800)", "numpy", "pandera", "pytest (>=4.0.0)", "sphinx", "typing-extensions"] +test-tox-coverage = ["coverage (>=5.5)"] + [[package]] name = "beautifulsoup4" version = "4.12.2" @@ -1427,14 +1446,14 @@ files = [ [[package]] name = "importlib-metadata" -version = "6.5.0" +version = "6.6.0" description = "Read metadata from Python packages" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "importlib_metadata-6.5.0-py3-none-any.whl", hash = "sha256:03ba783c3a2c69d751b109fc0c94a62c51f581b3d6acf8ed1331b6d5729321ff"}, - {file = "importlib_metadata-6.5.0.tar.gz", hash = "sha256:7a8bdf1bc3a726297f5cfbc999e6e7ff6b4fa41b26bba4afc580448624460045"}, + {file = "importlib_metadata-6.6.0-py3-none-any.whl", hash = "sha256:43dd286a2cd8995d5eaef7fee2066340423b818ed3fd70adf0bad5f1fac53fed"}, + {file = "importlib_metadata-6.6.0.tar.gz", hash = "sha256:92501cdf9cc66ebd3e612f1b4f0c0765dfa42f0fa38ffb319b6bd84dd675d705"}, ] [package.dependencies] @@ -2482,14 +2501,14 @@ testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4, [[package]] name = "nbclient" -version = "0.7.3" +version = "0.7.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." category = "dev" optional = false python-versions = ">=3.7.0" files = [ - {file = "nbclient-0.7.3-py3-none-any.whl", hash = "sha256:8fa96f7e36693d5e83408f5e840f113c14a45c279befe609904dbe05dad646d1"}, - {file = "nbclient-0.7.3.tar.gz", hash = "sha256:26e41c6dca4d76701988bc34f64e1bfc2413ae6d368f13d7b5ac407efb08c755"}, + {file = "nbclient-0.7.4-py3-none-any.whl", hash = "sha256:c817c0768c5ff0d60e468e017613e6eae27b6fa31e43f905addd2d24df60c125"}, + {file = "nbclient-0.7.4.tar.gz", hash = "sha256:d447f0e5a4cfe79d462459aec1b3dc5c2e9152597262be8ee27f7d4c02566a0d"}, ] [package.dependencies] @@ -2638,40 +2657,40 @@ tox-to-nox = ["jinja2", "tox"] [[package]] name = "numpy" -version = "1.24.2" +version = "1.24.3" description = "Fundamental package for array computing in Python" category = "main" optional = false python-versions = ">=3.8" files = [ - {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"}, - {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"}, - {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6524630f71631be2dabe0c541e7675db82651eb998496bbe16bc4f77f0772253"}, - {file = "numpy-1.24.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a51725a815a6188c662fb66fb32077709a9ca38053f0274640293a14fdd22978"}, - {file = "numpy-1.24.2-cp310-cp310-win32.whl", hash = "sha256:2620e8592136e073bd12ee4536149380695fbe9ebeae845b81237f986479ffc9"}, - {file = "numpy-1.24.2-cp310-cp310-win_amd64.whl", hash = "sha256:97cf27e51fa078078c649a51d7ade3c92d9e709ba2bfb97493007103c741f1d0"}, - {file = "numpy-1.24.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7de8fdde0003f4294655aa5d5f0a89c26b9f22c0a58790c38fae1ed392d44a5a"}, - {file = "numpy-1.24.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4173bde9fa2a005c2c6e2ea8ac1618e2ed2c1c6ec8a7657237854d42094123a0"}, - {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4cecaed30dc14123020f77b03601559fff3e6cd0c048f8b5289f4eeabb0eb281"}, - {file = "numpy-1.24.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a23f8440561a633204a67fb44617ce2a299beecf3295f0d13c495518908e910"}, - {file = "numpy-1.24.2-cp311-cp311-win32.whl", hash = "sha256:e428c4fbfa085f947b536706a2fc349245d7baa8334f0c5723c56a10595f9b95"}, - {file = "numpy-1.24.2-cp311-cp311-win_amd64.whl", hash = "sha256:557d42778a6869c2162deb40ad82612645e21d79e11c1dc62c6e82a2220ffb04"}, - {file = "numpy-1.24.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d0a2db9d20117bf523dde15858398e7c0858aadca7c0f088ac0d6edd360e9ad2"}, - {file = "numpy-1.24.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c72a6b2f4af1adfe193f7beb91ddf708ff867a3f977ef2ec53c0ffb8283ab9f5"}, - {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c29e6bd0ec49a44d7690ecb623a8eac5ab8a923bce0bea6293953992edf3a76a"}, - {file = "numpy-1.24.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2eabd64ddb96a1239791da78fa5f4e1693ae2dadc82a76bc76a14cbb2b966e96"}, - {file = "numpy-1.24.2-cp38-cp38-win32.whl", hash = "sha256:e3ab5d32784e843fc0dd3ab6dcafc67ef806e6b6828dc6af2f689be0eb4d781d"}, - {file = "numpy-1.24.2-cp38-cp38-win_amd64.whl", hash = "sha256:76807b4063f0002c8532cfeac47a3068a69561e9c8715efdad3c642eb27c0756"}, - {file = "numpy-1.24.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4199e7cfc307a778f72d293372736223e39ec9ac096ff0a2e64853b866a8e18a"}, - {file = "numpy-1.24.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:adbdce121896fd3a17a77ab0b0b5eedf05a9834a18699db6829a64e1dfccca7f"}, - {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:889b2cc88b837d86eda1b17008ebeb679d82875022200c6e8e4ce6cf549b7acb"}, - {file = "numpy-1.24.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64bb98ac59b3ea3bf74b02f13836eb2e24e48e0ab0145bbda646295769bd780"}, - {file = "numpy-1.24.2-cp39-cp39-win32.whl", hash = "sha256:63e45511ee4d9d976637d11e6c9864eae50e12dc9598f531c035265991910468"}, - {file = "numpy-1.24.2-cp39-cp39-win_amd64.whl", hash = "sha256:a77d3e1163a7770164404607b7ba3967fb49b24782a6ef85d9b5f54126cc39e5"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:92011118955724465fb6853def593cf397b4a1367495e0b59a7e69d40c4eb71d"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9006288bcf4895917d02583cf3411f98631275bc67cce355a7f39f8c14338fa"}, - {file = "numpy-1.24.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:150947adbdfeceec4e5926d956a06865c1c690f2fd902efede4ca6fe2e657c3f"}, - {file = "numpy-1.24.2.tar.gz", hash = "sha256:003a9f530e880cb2cd177cba1af7220b9aa42def9c4afc2a2fc3ee6be7eb2b22"}, + {file = "numpy-1.24.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c1104d3c036fb81ab923f507536daedc718d0ad5a8707c6061cdfd6d184e570"}, + {file = "numpy-1.24.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:202de8f38fc4a45a3eea4b63e2f376e5f2dc64ef0fa692838e31a808520efaf7"}, + {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8535303847b89aa6b0f00aa1dc62867b5a32923e4d1681a35b5eef2d9591a463"}, + {file = "numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d926b52ba1367f9acb76b0df6ed21f0b16a1ad87c6720a1121674e5cf63e2b6"}, + {file = "numpy-1.24.3-cp310-cp310-win32.whl", hash = "sha256:f21c442fdd2805e91799fbe044a7b999b8571bb0ab0f7850d0cb9641a687092b"}, + {file = "numpy-1.24.3-cp310-cp310-win_amd64.whl", hash = "sha256:ab5f23af8c16022663a652d3b25dcdc272ac3f83c3af4c02eb8b824e6b3ab9d7"}, + {file = "numpy-1.24.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9a7721ec204d3a237225db3e194c25268faf92e19338a35f3a224469cb6039a3"}, + {file = "numpy-1.24.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d6cc757de514c00b24ae8cf5c876af2a7c3df189028d68c0cb4eaa9cd5afc2bf"}, + {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76e3f4e85fc5d4fd311f6e9b794d0c00e7002ec122be271f2019d63376f1d385"}, + {file = "numpy-1.24.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1d3c026f57ceaad42f8231305d4653d5f05dc6332a730ae5c0bea3513de0950"}, + {file = "numpy-1.24.3-cp311-cp311-win32.whl", hash = "sha256:c91c4afd8abc3908e00a44b2672718905b8611503f7ff87390cc0ac3423fb096"}, + {file = "numpy-1.24.3-cp311-cp311-win_amd64.whl", hash = "sha256:5342cf6aad47943286afa6f1609cad9b4266a05e7f2ec408e2cf7aea7ff69d80"}, + {file = "numpy-1.24.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7776ea65423ca6a15255ba1872d82d207bd1e09f6d0894ee4a64678dd2204078"}, + {file = "numpy-1.24.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ae8d0be48d1b6ed82588934aaaa179875e7dc4f3d84da18d7eae6eb3f06c242c"}, + {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecde0f8adef7dfdec993fd54b0f78183051b6580f606111a6d789cd14c61ea0c"}, + {file = "numpy-1.24.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4749e053a29364d3452c034827102ee100986903263e89884922ef01a0a6fd2f"}, + {file = "numpy-1.24.3-cp38-cp38-win32.whl", hash = "sha256:d933fabd8f6a319e8530d0de4fcc2e6a61917e0b0c271fded460032db42a0fe4"}, + {file = "numpy-1.24.3-cp38-cp38-win_amd64.whl", hash = "sha256:56e48aec79ae238f6e4395886b5eaed058abb7231fb3361ddd7bfdf4eed54289"}, + {file = "numpy-1.24.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4719d5aefb5189f50887773699eaf94e7d1e02bf36c1a9d353d9f46703758ca4"}, + {file = "numpy-1.24.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0ec87a7084caa559c36e0a2309e4ecb1baa03b687201d0a847c8b0ed476a7187"}, + {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea8282b9bcfe2b5e7d491d0bf7f3e2da29700cec05b49e64d6246923329f2b02"}, + {file = "numpy-1.24.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:210461d87fb02a84ef243cac5e814aad2b7f4be953b32cb53327bb49fd77fbb4"}, + {file = "numpy-1.24.3-cp39-cp39-win32.whl", hash = "sha256:784c6da1a07818491b0ffd63c6bbe5a33deaa0e25a20e1b3ea20cf0e43f8046c"}, + {file = "numpy-1.24.3-cp39-cp39-win_amd64.whl", hash = "sha256:d5036197ecae68d7f491fcdb4df90082b0d4960ca6599ba2659957aafced7c17"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:352ee00c7f8387b44d19f4cada524586f07379c0d49270f87233983bc5087ca0"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a7d6acc2e7524c9955e5c903160aa4ea083736fde7e91276b0e5d98e6332812"}, + {file = "numpy-1.24.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:35400e6a8d102fd07c71ed7dcadd9eb62ee9a6e84ec159bd48c28235bbb0f8e4"}, + {file = "numpy-1.24.3.tar.gz", hash = "sha256:ab344f1bf21f140adab8e47fdbc7c35a477dc01408791f8ba00d018dd0bc5155"}, ] [[package]] @@ -2988,19 +3007,19 @@ dev = ["black", "flake8", "mypy", "pytest", "responses", "setuptools", "types-re [[package]] name = "platformdirs" -version = "3.2.0" +version = "3.3.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "platformdirs-3.2.0-py3-none-any.whl", hash = "sha256:ebe11c0d7a805086e99506aa331612429a72ca7cd52a1f0d277dc4adc20cb10e"}, - {file = "platformdirs-3.2.0.tar.gz", hash = "sha256:d5b638ca397f25f979350ff789db335903d7ea010ab28903f57b27e1b16c2b08"}, + {file = "platformdirs-3.3.0-py3-none-any.whl", hash = "sha256:ea61fd7b85554beecbbd3e9b37fb26689b227ffae38f73353cbcc1cf8bd01878"}, + {file = "platformdirs-3.3.0.tar.gz", hash = "sha256:64370d47dc3fca65b4879f89bdead8197e93e05d696d6d1816243ebae8595da5"}, ] [package.extras] -docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +docs = ["furo (>=2023.3.27)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] [[package]] name = "pluggy" @@ -3233,18 +3252,18 @@ plugins = ["importlib-metadata"] [[package]] name = "pylint" -version = "2.17.2" +version = "2.17.3" description = "python code static checker" category = "dev" optional = false python-versions = ">=3.7.2" files = [ - {file = "pylint-2.17.2-py3-none-any.whl", hash = "sha256:001cc91366a7df2970941d7e6bbefcbf98694e00102c1f121c531a814ddc2ea8"}, - {file = "pylint-2.17.2.tar.gz", hash = "sha256:1b647da5249e7c279118f657ca28b6aaebb299f86bf92affc632acf199f7adbb"}, + {file = "pylint-2.17.3-py3-none-any.whl", hash = "sha256:a6cbb4c6e96eab4a3c7de7c6383c512478f58f88d95764507d84c899d656a89a"}, + {file = "pylint-2.17.3.tar.gz", hash = "sha256:761907349e699f8afdcd56c4fe02f3021ab5b3a0fc26d19a9bfdc66c7d0d5cd5"}, ] [package.dependencies] -astroid = ">=2.15.2,<=2.17.0-dev0" +astroid = ">=2.15.4,<=2.17.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < \"3.11\""}, @@ -3363,14 +3382,14 @@ files = [ [[package]] name = "pystac" -version = "1.7.2" +version = "1.7.3" description = "Python library for working with Spatiotemporal Asset Catalog (STAC)." category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "pystac-1.7.2-py3-none-any.whl", hash = "sha256:29c6f053741e2fb942502e33a3c61b5217a286838c50e66ccd06be82fd6bd664"}, - {file = "pystac-1.7.2.tar.gz", hash = "sha256:049e8ece607e872241e872fd1f84dab5a2003b36dcbab9a013542f96d1b6c95d"}, + {file = "pystac-1.7.3-py3-none-any.whl", hash = "sha256:2b1b5e11b995e443376ca1d195609d95723f690c8d192604bc00091fcdf52e4c"}, + {file = "pystac-1.7.3.tar.gz", hash = "sha256:6848074fad6665ac631abd62c692bb868de37379615db90f4d913dca37f844ce"}, ] [package.dependencies] @@ -3933,14 +3952,14 @@ stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"] [[package]] name = "setuptools" -version = "67.7.1" +version = "67.7.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-67.7.1-py3-none-any.whl", hash = "sha256:6f0839fbdb7e3cfef1fc38d7954f5c1c26bf4eebb155a55c9bf8faf997b9fb67"}, - {file = "setuptools-67.7.1.tar.gz", hash = "sha256:bb16732e8eb928922eabaa022f881ae2b7cdcfaf9993ef1f5e841a96d32b8e0c"}, + {file = "setuptools-67.7.2-py3-none-any.whl", hash = "sha256:23aaf86b85ca52ceb801d32703f12d77517b2556af839621c641fca11287952b"}, + {file = "setuptools-67.7.2.tar.gz", hash = "sha256:f104fa03692a2602fa0fec6c6a9e63b6c8a968de13e17c026957dd1f53d80990"}, ] [package.extras] @@ -4076,21 +4095,21 @@ files = [ [[package]] name = "sphinx" -version = "6.1.3" +version = "6.2.1" description = "Python documentation generator" category = "dev" optional = false python-versions = ">=3.8" files = [ - {file = "Sphinx-6.1.3.tar.gz", hash = "sha256:0dac3b698538ffef41716cf97ba26c1c7788dba73ce6f150c1ff5b4720786dd2"}, - {file = "sphinx-6.1.3-py3-none-any.whl", hash = "sha256:807d1cb3d6be87eb78a381c3e70ebd8d346b9a25f3753e9947e866b2786865fc"}, + {file = "Sphinx-6.2.1.tar.gz", hash = "sha256:6d56a34697bb749ffa0152feafc4b19836c755d90a7c59b72bc7dfd371b9cc6b"}, + {file = "sphinx-6.2.1-py3-none-any.whl", hash = "sha256:97787ff1fa3256a3eef9eda523a63dbf299f7b47e053cfcf684a1c2a8380c912"}, ] [package.dependencies] alabaster = ">=0.7,<0.8" babel = ">=2.9" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} -docutils = ">=0.18,<0.20" +docutils = ">=0.18.1,<0.20" imagesize = ">=1.3" importlib-metadata = {version = ">=4.8", markers = "python_version < \"3.10\""} Jinja2 = ">=3.0" @@ -4108,7 +4127,7 @@ sphinxcontrib-serializinghtml = ">=1.1.5" [package.extras] docs = ["sphinxcontrib-websupport"] lint = ["docutils-stubs", "flake8 (>=3.5.0)", "flake8-simplify", "isort", "mypy (>=0.990)", "ruff", "sphinx-lint", "types-requests"] -test = ["cython", "html5lib", "pytest (>=4.6)"] +test = ["cython", "filelock", "html5lib", "pytest (>=4.6)"] [[package]] name = "sphinx-book-theme" @@ -4456,23 +4475,23 @@ files = [ [[package]] name = "tornado" -version = "6.3" +version = "6.3.1" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." category = "dev" optional = false python-versions = ">= 3.8" files = [ - {file = "tornado-6.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:6cfff1e9c15c79e106b8352269d201f8fc0815914a6260f3893ca18b724ea94b"}, - {file = "tornado-6.3-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6164571f5b9f73143d1334df4584cb9ac86d20c461e17b6c189a19ead8bb93c1"}, - {file = "tornado-6.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4546003dc8b5733489139d3bff5fa6a0211be505faf819bd9970e7c2b32e8122"}, - {file = "tornado-6.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c659ab04d5aa477dbe44152c67d93f3ad3243b992d94f795ca1d5c73c37337ce"}, - {file = "tornado-6.3-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:912df5712024564e362ecce43c8d5862e14c78c8dd3846c9d889d44fbd7f4951"}, - {file = "tornado-6.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:c37b6a384d54ce6a31168d40ab21ad2591ddaf34973075cc0cad154402ecd9e8"}, - {file = "tornado-6.3-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:c9114a61a4588c09065b9996ae05462350d17160b92b9bf9a1e93689cc0424dc"}, - {file = "tornado-6.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:4d349846931557b7ec92f224b5d598b160e2ba26ae1812480b42e9622c884bf7"}, - {file = "tornado-6.3-cp38-abi3-win32.whl", hash = "sha256:d7b737e18f701de3e4a3b0824260b4d740e4d60607b8089bb80e80ffd464780e"}, - {file = "tornado-6.3-cp38-abi3-win_amd64.whl", hash = "sha256:720f53e6367b38190ae7fa398c25c086c69d88b3c6535bd6021a126b727fb5cd"}, - {file = "tornado-6.3.tar.gz", hash = "sha256:d68f3192936ff2c4add04dc21a436a43b4408d466746b78bb2b9d0a53a18683f"}, + {file = "tornado-6.3.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:db181eb3df8738613ff0a26f49e1b394aade05034b01200a63e9662f347d4415"}, + {file = "tornado-6.3.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b4e7b956f9b5e6f9feb643ea04f07e7c6b49301e03e0023eedb01fa8cf52f579"}, + {file = "tornado-6.3.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9661aa8bc0e9d83d757cd95b6f6d1ece8ca9fd1ccdd34db2de381e25bf818233"}, + {file = "tornado-6.3.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:81c17e0cc396908a5e25dc8e9c5e4936e6dfd544c9290be48bd054c79bcad51e"}, + {file = "tornado-6.3.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a27a1cfa9997923f80bdd962b3aab048ac486ad8cfb2f237964f8ab7f7eb824b"}, + {file = "tornado-6.3.1-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:d7117f3c7ba5d05813b17a1f04efc8e108a1b811ccfddd9134cc68553c414864"}, + {file = "tornado-6.3.1-cp38-abi3-musllinux_1_1_i686.whl", hash = "sha256:ffdce65a281fd708da5a9def3bfb8f364766847fa7ed806821a69094c9629e8a"}, + {file = "tornado-6.3.1-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:90f569a35a8ec19bde53aa596952071f445da678ec8596af763b9b9ce07605e6"}, + {file = "tornado-6.3.1-cp38-abi3-win32.whl", hash = "sha256:3455133b9ff262fd0a75630af0a8ee13564f25fb4fd3d9ce239b8a7d3d027bf8"}, + {file = "tornado-6.3.1-cp38-abi3-win_amd64.whl", hash = "sha256:1285f0691143f7ab97150831455d4db17a267b59649f7bd9700282cba3d5e771"}, + {file = "tornado-6.3.1.tar.gz", hash = "sha256:5e2f49ad371595957c50e42dd7e5c14d64a6843a3cf27352b69c706d1b5918af"}, ] [[package]] @@ -4760,86 +4779,86 @@ viz = ["matplotlib", "nc-time-axis", "seaborn"] [[package]] name = "yarl" -version = "1.9.1" +version = "1.9.2" description = "Yet another URL library" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "yarl-1.9.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e124b283a04cc06d22443cae536f93d86cd55108fa369f22b8fe1f2288b2fe1c"}, - {file = "yarl-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:56956b13ec275de31fe4fb991510b735c4fb3e1b01600528c952b9ac90464430"}, - {file = "yarl-1.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ecaa5755a39f6f26079bf13f336c67af589c222d76b53cd3824d3b684b84d1f1"}, - {file = "yarl-1.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:92a101f6d5a9464e86092adc36cd40ef23d18a25bfb1eb32eaeb62edc22776bb"}, - {file = "yarl-1.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92e37999e36f9f3ded78e9d839face6baa2abdf9344ea8ed2735f495736159de"}, - {file = "yarl-1.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef7e2f6c47c41e234600a02e1356b799761485834fe35d4706b0094cb3a587ee"}, - {file = "yarl-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d7a0075a55380b19aa43b9e8056e128b058460d71d75018a4f9d60ace01e78c"}, - {file = "yarl-1.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2f01351b7809182822b21061d2a4728b7b9e08f4585ba90ee4c5c4d3faa0812"}, - {file = "yarl-1.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6cf47fe9df9b1ededc77e492581cdb6890a975ad96b4172e1834f1b8ba0fc3ba"}, - {file = "yarl-1.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:098bdc06ffb4db39c73883325b8c738610199f5f12e85339afedf07e912a39af"}, - {file = "yarl-1.9.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:6cdb47cbbacae8e1d7941b0d504d0235d686090eef5212ca2450525905e9cf02"}, - {file = "yarl-1.9.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:73a4b46689f2d59c8ec6b71c9a0cdced4e7863dd6eb98a8c30ea610e191f9e1c"}, - {file = "yarl-1.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:65d952e464df950eed32bb5dcbc1b4443c7c2de4d7abd7265b45b1b3b27f5fa2"}, - {file = "yarl-1.9.1-cp310-cp310-win32.whl", hash = "sha256:39a7a9108e9fc633ae381562f8f0355bb4ba00355218b5fb19cf5263fcdbfa68"}, - {file = "yarl-1.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:b63d41e0eecf3e3070d44f97456cf351fff7cb960e97ecb60a936b877ff0b4f6"}, - {file = "yarl-1.9.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4295790981630c4dab9d6de7b0f555a4c8defe3ed7704a8e9e595a321e59a0f5"}, - {file = "yarl-1.9.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b2b2382d59dec0f1fdca18ea429c4c4cee280d5e0dbc841180abb82e188cf6e9"}, - {file = "yarl-1.9.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:575975d28795a61e82c85f114c02333ca54cbd325fd4e4b27598c9832aa732e7"}, - {file = "yarl-1.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9bb794882818fae20ff65348985fdf143ea6dfaf6413814db1848120db8be33e"}, - {file = "yarl-1.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89da1fd6068553e3a333011cc17ad91c414b2100c32579ddb51517edc768b49c"}, - {file = "yarl-1.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4d817593d345fefda2fae877accc8a0d9f47ada57086da6125fa02a62f6d1a94"}, - {file = "yarl-1.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85aa6fd779e194901386709e0eedd45710b68af2709f82a84839c44314b68c10"}, - {file = "yarl-1.9.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eed9827033b7f67ad12cb70bd0cb59d36029144a7906694317c2dbf5c9eb5ddd"}, - {file = "yarl-1.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:df747104ef27ab1aa9a1145064fa9ea26ad8cf24bfcbdba7db7abf0f8b3676b9"}, - {file = "yarl-1.9.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:efec77851231410125cb5be04ec96fa4a075ca637f415a1f2d2c900b09032a8a"}, - {file = "yarl-1.9.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:d5c407e530cf2979ea383885516ae79cc4f3c3530623acf5e42daf521f5c2564"}, - {file = "yarl-1.9.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:f76edb386178a54ea7ceffa798cb830c3c22ab50ea10dfb25dc952b04848295f"}, - {file = "yarl-1.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:75676110bce59944dd48fd18d0449bd37eaeb311b38a0c768f7670864b5f8b68"}, - {file = "yarl-1.9.1-cp311-cp311-win32.whl", hash = "sha256:9ba5a18c4fbd408fe49dc5da85478a76bc75c1ce912d7fd7b43ed5297c4403e1"}, - {file = "yarl-1.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:b20a5ddc4e243cbaa54886bfe9af6ffc4ba4ef58f17f1bb691e973eb65bba84d"}, - {file = "yarl-1.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:791357d537a09a194f92b834f28c98d074e7297bac0a8f1d5b458a906cafa17c"}, - {file = "yarl-1.9.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89099c887338608da935ba8bee027564a94f852ac40e472de15d8309517ad5fe"}, - {file = "yarl-1.9.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:395ea180257a3742d09dcc5071739682a95f7874270ebe3982d6696caec75be0"}, - {file = "yarl-1.9.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:90ebaf448b5f048352ec7c76cb8d452df30c27cb6b8627dfaa9cf742a14f141a"}, - {file = "yarl-1.9.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f878a78ed2ccfbd973cab46dd0933ecd704787724db23979e5731674d76eb36f"}, - {file = "yarl-1.9.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74390c2318d066962500045aa145f5412169bce842e734b8c3e6e3750ad5b817"}, - {file = "yarl-1.9.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:f8e73f526140c1c32f5fca4cd0bc3b511a1abcd948f45b2a38a95e4edb76ca72"}, - {file = "yarl-1.9.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:ac8e593df1fbea820da7676929f821a0c7c2cecb8477d010254ce8ed54328ea8"}, - {file = "yarl-1.9.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:01cf88cb80411978a14aa49980968c1aeb7c18a90ac978c778250dd234d8e0ba"}, - {file = "yarl-1.9.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:97d76a3128f48fa1c721ef8a50e2c2f549296b2402dc8a8cde12ff60ed922f53"}, - {file = "yarl-1.9.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:01a073c9175481dfed6b40704a1b67af5a9435fc4a58a27d35fd6b303469b0c7"}, - {file = "yarl-1.9.1-cp37-cp37m-win32.whl", hash = "sha256:ecad20c3ef57c513dce22f58256361d10550a89e8eaa81d5082f36f8af305375"}, - {file = "yarl-1.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f5bcb80006efe9bf9f49ae89711253dd06df8053ff814622112a9219346566a7"}, - {file = "yarl-1.9.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e7ddebeabf384099814353a2956ed3ab5dbaa6830cc7005f985fcb03b5338f05"}, - {file = "yarl-1.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:13a1ad1f35839b3bb5226f59816b71e243d95d623f5b392efaf8820ddb2b3cd5"}, - {file = "yarl-1.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f0cd87949d619157a0482c6c14e5011f8bf2bc0b91cb5087414d9331f4ef02dd"}, - {file = "yarl-1.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d21887cbcf6a3cc5951662d8222bc9c04e1b1d98eebe3bb659c3a04ed49b0eec"}, - {file = "yarl-1.9.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4764114e261fe49d5df9b316b3221493d177247825c735b2aae77bc2e340d800"}, - {file = "yarl-1.9.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3abe37fd89a93ebe0010417ca671f422fa6fcffec54698f623b09f46b4d4a512"}, - {file = "yarl-1.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fe3a1c073ab80a28a06f41d2b623723046709ed29faf2c56bea41848597d86"}, - {file = "yarl-1.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3b5f8da07a21f2e57551f88a6709c2d340866146cf7351e5207623cfe8aad16"}, - {file = "yarl-1.9.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:88f6413ff5edfb9609e2769e32ce87a62353e66e75d264bf0eaad26fb9daa8f2"}, - {file = "yarl-1.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b5d5fb6c94b620a7066a3adb7c246c87970f453813979818e4707ac32ce4d7bd"}, - {file = "yarl-1.9.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:f206adb89424dca4a4d0b31981869700e44cd62742527e26d6b15a510dd410a2"}, - {file = "yarl-1.9.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:44fa6158e6b4b8ccfa2872c3900a226b29e8ce543ce3e48aadc99816afa8874d"}, - {file = "yarl-1.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08c8599d6aa8a24425f8635f6c06fa8726afe3be01c8e53e236f519bcfa5db5b"}, - {file = "yarl-1.9.1-cp38-cp38-win32.whl", hash = "sha256:6b09cce412386ea9b4dda965d8e78d04ac5b5792b2fa9cced3258ec69c7d1c16"}, - {file = "yarl-1.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:09c56a32c26e24ef98d5757c5064e252836f621f9a8b42737773aa92936b8e08"}, - {file = "yarl-1.9.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b86e98c3021b7e2740d8719bf074301361bf2f51221ca2765b7a58afbfbd9042"}, - {file = "yarl-1.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5faf3ec98747318cb980aaf9addf769da68a66431fc203a373d95d7ee9c1fbb4"}, - {file = "yarl-1.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a21789bdf28549d4eb1de6910cabc762c9f6ae3eef85efc1958197c1c6ef853b"}, - {file = "yarl-1.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8b8d4b478a9862447daef4cafc89d87ea4ed958672f1d11db7732b77ead49cc"}, - {file = "yarl-1.9.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:307a782736ebf994e7600dcaeea3b3113083584da567272f2075f1540919d6b3"}, - {file = "yarl-1.9.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:46c4010de941e2e1365c07fb4418ddca10fcff56305a6067f5ae857f8c98f3a7"}, - {file = "yarl-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bab67d041c78e305ff3eef5e549304d843bd9b603c8855b68484ee663374ce15"}, - {file = "yarl-1.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1baf8cdaaab65d9ccedbf8748d626ad648b74b0a4d033e356a2f3024709fb82f"}, - {file = "yarl-1.9.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:27efc2e324f72df02818cd72d7674b1f28b80ab49f33a94f37c6473c8166ce49"}, - {file = "yarl-1.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ca14b84091700ae7c1fcd3a6000bd4ec1a3035009b8bcb94f246741ca840bb22"}, - {file = "yarl-1.9.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c3ca8d71b23bdf164b36d06df2298ec8a5bd3de42b17bf3e0e8e6a7489195f2c"}, - {file = "yarl-1.9.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:8c72a1dc7e2ea882cd3df0417c808ad3b69e559acdc43f3b096d67f2fb801ada"}, - {file = "yarl-1.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d966cd59df9a4b218480562e8daab39e87e746b78a96add51a3ab01636fc4291"}, - {file = "yarl-1.9.1-cp39-cp39-win32.whl", hash = "sha256:518a92a34c741836a315150460b5c1c71ae782d569eabd7acf53372e437709f7"}, - {file = "yarl-1.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:78755ce43b6e827e65ec0c68be832f86d059fcf05d4b33562745ebcfa91b26b1"}, - {file = "yarl-1.9.1.tar.gz", hash = "sha256:5ce0bcab7ec759062c818d73837644cde567ab8aa1e0d6c45db38dfb7c284441"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, + {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, + {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, + {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, + {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, + {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, + {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, + {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, + {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, + {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, + {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, + {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, + {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, + {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, + {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, + {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, + {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, + {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, + {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, + {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, + {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, + {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, + {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, + {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, + {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, + {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, + {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, ] [package.dependencies] @@ -4865,4 +4884,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.12" -content-hash = "524b89e698d9937495828ff20d1a2413bc3873ae3ce3fd3840a9fcad0f9fc839" +content-hash = "0f9a52d72e8daee02fd9d5c3c397c8abc6b3376deaa0d468e6b3c724449449a7" diff --git a/pyproject.toml b/pyproject.toml index 58c0dc7f..6d9a502d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,11 +19,12 @@ python = ">=3.8,<3.12" jax = ">=0.4.1" jaxlib = "^0.4.6" optax = "^0.1.4" -jaxtyping = "^0.2.14" +jaxtyping = "^0.2.15" tqdm = "^4.65.0" simple-pytree = "^0.1.7" tensorflow-probability = "^0.19.0" orbax-checkpoint = "^0.2.0" +beartype = "^0.13.1" [tool.poetry.group.test.dependencies] pytest = "^7.2.2" @@ -118,4 +119,4 @@ source = ["src"] output = "reports/coverage.xml" [tool.poetry-dynamic-versioning] -enable = true \ No newline at end of file +enable = true diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..3be32572 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +from jaxtyping import install_import_hook + +# import gpjax within import hook to apply beartype everywhere, before running tests +with install_import_hook("gpjax", "beartype.beartype"): + import gpjax diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 46e06b0c..6d685b24 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -15,6 +15,12 @@ from dataclasses import is_dataclass +try: + import beartype + ValidationErrors = (ValueError, beartype.roar.BeartypeCallHintParamViolation) +except ImportError: + ValidationErrors = ValueError + import jax.numpy as jnp import jax.tree_util as jtu import pytest @@ -107,7 +113,7 @@ def test_dataset_incorrect_lengths(nx: int, ny: int, out_dim: int, in_dim: int) y = jnp.ones((ny, out_dim)) # Ensure error is raised upon dataset creation - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): Dataset(X=x, y=y) @@ -120,7 +126,7 @@ def test_2d_inputs(n: int, out_dim: int, in_dim: int) -> None: y = jnp.ones((n,)) # Ensure error is raised upon dataset creation - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): Dataset(X=x, y=y) # Create dataset where input dimension is incorrectly not 2D @@ -128,7 +134,7 @@ def test_2d_inputs(n: int, out_dim: int, in_dim: int) -> None: y = jnp.ones((n, out_dim)) # Ensure error is raised upon dataset creation - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): Dataset(X=x, y=y) diff --git a/tests/test_gps.py b/tests/test_gps.py index cb8cc41d..61f687e6 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -13,6 +13,12 @@ # limitations under the License. # ============================================================================== +try: + import beartype + ValidationErrors = (ValueError, beartype.roar.BeartypeCallHintParamViolation) +except ImportError: + ValidationErrors = ValueError + import jax.numpy as jnp import jax.random as jr import pytest @@ -229,13 +235,13 @@ def test_prior_sample_approx(num_datapoints, kernel, mean_function): p.sample_approx(-1, key) with pytest.raises(ValueError): p.sample_approx(0, key) - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): p.sample_approx(0.5, key) with pytest.raises(ValueError): p.sample_approx(1, key, -10) with pytest.raises(ValueError): p.sample_approx(1, key, 0) - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): p.sample_approx(1, key, 0.5) sampled_fn = p.sample_approx(1, key, 100) @@ -290,13 +296,13 @@ def test_conjugate_posterior_sample_approx(num_datapoints, kernel, mean_function p.sample_approx(-1, D, key) with pytest.raises(ValueError): p.sample_approx(0, D, key) - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): p.sample_approx(0.5, D, key) with pytest.raises(ValueError): p.sample_approx(1, D, key, -10) with pytest.raises(ValueError): p.sample_approx(1, D, key, 0) - with pytest.raises(ValueError): + with pytest.raises(ValidationErrors): p.sample_approx(1, D, key, 0.5) sampled_fn = p.sample_approx(1, D, key, 100) diff --git a/tests/test_kernels/test_approximations.py b/tests/test_kernels/test_approximations.py index 10164b43..b768ed12 100644 --- a/tests/test_kernels/test_approximations.py +++ b/tests/test_kernels/test_approximations.py @@ -130,12 +130,12 @@ def test_exactness(kernel): ) def test_value_error(kernel): with pytest.raises(ValueError): - RFF(kernel(), num_basis_fns=10) + RFF(base_kernel=kernel(), num_basis_fns=10) @pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -def stochastic_init(kernel: AbstractKernel): - k1 = RFF(kernel, num_basis_fns=10, key=123) - k2 = RFF(kernel, num_basis_fns=10, key=42) +def test_stochastic_init(kernel: AbstractKernel): + k1 = RFF(base_kernel=kernel, num_basis_fns=10, key=jr.PRNGKey(123)) + k2 = RFF(base_kernel=kernel, num_basis_fns=10, key=jr.PRNGKey(42)) assert (k1.frequencies != k2.frequencies).any() diff --git a/tests/test_kernels/test_non_euclidean.py b/tests/test_kernels/test_non_euclidean.py index 133454d2..f90f2329 100644 --- a/tests/test_kernels/test_non_euclidean.py +++ b/tests/test_kernels/test_non_euclidean.py @@ -13,6 +13,7 @@ # # limitations under the License. # # ============================================================================== +import pytest import jax.numpy as jnp import jax.random as jr import networkx as nx diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py index fc4de81f..1f133088 100644 --- a/tests/test_kernels/test_nonstationary.py +++ b/tests/test_kernels/test_nonstationary.py @@ -64,7 +64,7 @@ def test_initialization(self, fields: dict, dim: int) -> None: assert is_dataclass(self.kernel) # Input fields as JAX arrays - fields = {k: jnp.array([v]) for k, v in fields.items()} + fields = {k: jnp.array(v) for k, v in fields.items()} # Test number of dimensions if dim is None: diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py index e69059b0..5ab5aa93 100644 --- a/tests/test_kernels/test_stationary.py +++ b/tests/test_kernels/test_stationary.py @@ -82,7 +82,7 @@ def test_initialization(self, fields: dict, dim: int) -> None: assert is_dataclass(self.kernel) # Input fields as JAX arrays - fields = {k: jnp.array([v]) for k, v in fields.items()} + fields = {k: jnp.array(v) for k, v in fields.items()} # Test number of dimensions if dim is None: diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index 537c68bf..8cb0ce29 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -1,4 +1,5 @@ import jax +import jax.numpy as jnp import pytest from jaxtyping import Array, Float @@ -13,28 +14,28 @@ def test_abstract() -> None: # Check a "dummy" mean funcion with defined abstract method, `__call__`, can be instantiated. class DummyMeanFunction(AbstractMeanFunction): def __call__(self, x: Float[Array, "D"]) -> Float[Array, "1"]: - return jax.numpy.array([1.0]) + return jnp.array([1.0]) mf = DummyMeanFunction() assert isinstance(mf, AbstractMeanFunction) - assert (mf(jax.numpy.array([1.0])) == jax.numpy.array([1.0])).all() - assert (mf(jax.numpy.array([2.0, 3.0])) == jax.numpy.array([1.0])).all() + assert (mf(jnp.array([1.0])) == jnp.array([1.0])).all() + assert (mf(jnp.array([2.0, 3.0])) == jnp.array([1.0])).all() @pytest.mark.parametrize( - "constant", [jax.numpy.array([0.0]), jax.numpy.array([1.0]), jax.numpy.array([3.0])] + "constant", [jnp.array([0.0]), jnp.array([1.0]), jnp.array([3.0])] ) def test_constant(constant: Float[Array, "Q"]) -> None: mf = Constant(constant=constant) assert isinstance(mf, AbstractMeanFunction) - assert (mf(jax.numpy.array([1.0])) == constant).all() - assert (mf(jax.numpy.array([2.0, 3.0])) == constant).all() + assert (mf(jnp.array([[1.0]])) == jnp.array([constant])).all() + assert (mf(jnp.array([[2.0, 3.0]])) == jnp.array([constant])).all() assert ( - jax.vmap(mf)(jax.numpy.array([[1.0], [2.0]])) - == jax.numpy.array([constant, constant]) + mf(jnp.array([[1.0], [2.0]])) + == jnp.array([constant, constant]) ).all() assert ( - jax.vmap(mf)(jax.numpy.array([[1.0, 2.0], [3.0, 4.0]])) - == jax.numpy.array([constant, constant]) + mf(jnp.array([[1.0, 2.0], [3.0, 4.0]])) + == jnp.array([constant, constant]) ).all() diff --git a/tests/test_objectives.py b/tests/test_objectives.py index 542ac2d0..da666046 100644 --- a/tests/test_objectives.py +++ b/tests/test_objectives.py @@ -23,7 +23,7 @@ def build_data(num_datapoints: int, num_dims: int, key, binary: bool): 0.5 * jnp.sign( jnp.cos( - 3 * x[:, 1].reshape(-1, 1) + 3 * x[:, 0].reshape(-1, 1) + jr.normal(key, shape=(num_datapoints, 1)) * 0.05 ) ) @@ -31,7 +31,7 @@ def build_data(num_datapoints: int, num_dims: int, key, binary: bool): ) else: y = ( - jnp.sin(x[:, 1]).reshape(-1, 1) + jnp.sin(x[:, 0]).reshape(-1, 1) + jr.normal(key=key, shape=(num_datapoints, 1)) * 0.1 ) D = Dataset(X=x, y=y) diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index d06041f9..7cab00e9 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -25,6 +25,7 @@ import jax.tree_util as jtu import gpjax as gpx +from gpjax.gps import AbstractPosterior from gpjax.variational_families import (AbstractVariationalFamily, CollapsedVariationalGaussian, ExpectationVariationalGaussian, @@ -44,12 +45,16 @@ def test_abstract_variational_family(): AbstractVariationalFamily() # Create a dummy variational family class with abstract methods implemented. + class DummyPosterior: + @property + def __class__(self) -> type: return AbstractPosterior + class DummyVariationalFamily(AbstractVariationalFamily): def predict(self, x: Float[Array, "N D"]) -> tfd.Distribution: return tfd.MultivariateNormalDiag(loc=x) # Test that the dummy variational family can be instantiated. - dummy_variational_family = DummyVariationalFamily(posterior=None) + dummy_variational_family = DummyVariationalFamily(posterior=DummyPosterior()) assert isinstance(dummy_variational_family, AbstractVariationalFamily) @@ -239,4 +244,4 @@ def test_collapsed_variational_gaussian( for l1, l2 in zip(jtu.tree_leaves(variational_family), true_leaves): assert l1.shape == l2.shape - assert (l1 == l2).all() \ No newline at end of file + assert (l1 == l2).all()