diff --git a/.circleci/config.yml b/.circleci/config.yml index 4a691e71..2e73a4ea 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,7 +1,7 @@ version: 2.1 orbs: - python: circleci/python@2.1.1 + # python: circleci/python@2.1.1 codecov: codecov/codecov@3.2.2 commands: @@ -51,7 +51,6 @@ commands: - run: name: Upload to PyPI command: twine upload dist/* -r << parameters.pkgname >> --verbose - install_pandoc: description: "Install pandoc" parameters: @@ -83,15 +82,12 @@ jobs: resource_class: large steps: - checkout - - restore_cache: - keys: - - pip-cache - run: - name: Update pip - command: pip install --upgrade pip - - python/install-packages: - pkg-manager: pip-dist - path-args: .[dev] + name: Install dependencies + command: | + pip install --upgrade pip + pip install -r requirements/dev.txt + pip install -e . - run: name: Run tests command: | @@ -103,10 +99,6 @@ jobs: curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov chmod +x codecov ./codecov -t ${CODECOV_TOKEN} - - save_cache: - key: pip-cache - paths: - - ~/.cache/pip - store_test_results: path: test-results - store_artifacts: diff --git a/README.md b/README.md index 0e1f60ec..ac5def1c 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ parameter_state = gpx.initialise(posterior, key=key) Finally, we run an optimisation loop using the Adam optimiser via the `fit` callable. ```python -inference_state = gpx.fit(mll, parameter_state, opt, n_iters=500) +inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500) ``` ## 3. Making predictions diff --git a/docs/README.md b/docs/README.md index 986a3132..c06b6047 100644 --- a/docs/README.md +++ b/docs/README.md @@ -41,7 +41,6 @@ description and a code example. The docstring is concluded with a description of the objects attributes with corresponding types. ```python -@dataclass class Prior(AbstractPrior): """A Gaussian process prior object. The GP is parameterised by a `mean `_ @@ -78,9 +77,4 @@ class Prior(AbstractPrior): ### Documentation syntax A helpful cheatsheet for writing restructured text can be found -[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting -`dataclass` objects. - -* Class attributes should be specified using the `Attributes:` tag. -* Method argument should be specified using the `Args:` tags. -* All attributes and arguments should have types. +[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..36841b62 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,85 @@ +# Where to find the docs + +The GPJax documentation can be found here: +https://gpjax.readthedocs.io/en/latest/ + +# How to build the docs + +1. Install the requirements using `pip install -r docs/requirements.txt` +2. Make sure `pandoc` is installed +3. Run the make script `make html` + +The corresponding HTML files can then be found in `docs/_build/html/`. + +# How to write code documentation + +Our documentation it is written in ReStructuredText for Sphinx. This is a +meta-language that is compiled into online documentation. For more details see +[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html). +As a result, our docstrings adhere to a specific syntax that has to be kept in +mind. Below we provide some guidelines. 
+
+## How much information to put in a docstring
+
+A docstring should be informative. If in doubt, then it is better to add more
+information to a docstring than less. Many users will skim documentation, so
+please ensure the opening sentence or two of a docstring contains the core
+information. Adding examples and mathematical descriptions to documentation is
+highly desirable.
+
+We are making an active effort within GPJax to improve our documentation. If you
+spot any areas where there is missing information within the existing
+documentation, then please either raise an issue or
+[create a pull request](https://gpjax.readthedocs.io/en/latest/contributing.html).
+
+## An example docstring
+
+An example docstring that adheres to the principles of GPJax is given below.
+The docstring contains a simple, snappy introduction with links to auxiliary
+components. More detail is then provided in the form of a mathematical
+description and a code example. The docstring is concluded with a description
+of the object's attributes with corresponding types.
+
+```python
+class Prior(AbstractPrior):
+    """A Gaussian process prior object. The GP is parameterised by a
+    `mean `_
+    and `kernel `_ function.
+
+    A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel
+    function :math:`k(\\cdot, \\cdot)` is given by
+
+    .. math::
+
+        p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)).
+
+    To invoke a ``Prior`` distribution, only a kernel function is required. By default,
+    the mean function will be set to zero. In general, this assumption will be reasonable
+    assuming the data being modelled has been centred.
+
+    Example:
+        >>> import gpjax as gpx
+        >>>
+        >>> kernel = gpx.kernels.RBF()
+        >>> prior = gpx.Prior(kernel = kernel)
+
+    Attributes:
+        kernel (Kernel): The kernel function used to parameterise the prior.
+        mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero.
+        name (str): The name of the GP prior. Defaults to "GP prior".
+    """
+
+    kernel: Kernel
+    mean_function: Optional[AbstractMeanFunction] = Zero()
+    name: Optional[str] = "GP prior"
+```
+
+### Documentation syntax
+
+A helpful cheatsheet for writing restructured text can be found
+[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following conventions when documenting
+classes.
+
+* Class attributes should be specified using the `Attributes:` tag.
+* Method arguments should be specified using the `Args:` tag, as illustrated in the sketch below.
+* All attributes and arguments should have types.
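+
+For illustration, a method docstring that follows the `Args:` convention might
+look as follows. The function, its parameters, and its types are purely
+illustrative and do not correspond to a particular GPJax object.
+
+```python
+from typing import Dict
+
+import distrax as dx
+from jaxtyping import Array, Float
+
+
+def predict(params: Dict, test_inputs: Float[Array, "N D"]) -> dx.Distribution:
+    """Evaluate the predictive distribution at the supplied test inputs.
+
+    Args:
+        params (Dict): The parameter set used to form the prediction.
+        test_inputs (Float[Array, "N D"]): The test inputs at which the
+            prediction is evaluated.
+
+    Returns:
+        dx.Distribution: The predictive distribution evaluated at the test inputs.
+    """
+    ...
+```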
diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index c4f3dbde..035efb70 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -115,7 +115,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri: objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/classification.pct.py b/examples/classification.pct.py index c0996daa..dddb32d2 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -91,7 +91,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) map_estimate, training_history = inference_state.unpack() diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index 244c4cdf..725ade7a 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -109,7 +109,7 @@ objective=negative_elbo, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=2000, + num_iters=2000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index dc7b1a2d..6cb52c65 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -137,7 +137,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 166ef065..e98978f5 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -185,7 +185,7 @@ def forward(x): objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=2500, + num_iters=2500, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index a2f4d748..c4408c6d 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -228,28 +228,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: # domain is a circle, this is $2\pi$. Next we define the kernel's `__call__` # function which is a direct implementation of Equation (1). Finally, we define # the Kernel's parameter property which contains just one value $\tau$ that we -# initialise to 4 in the kernel's `__post_init__`. -# -# #### Aside on dataclasses -# -# One can see in the above definition of a `Polar` kernel that we decorated the -# class with a `@dataclass` command. Dataclasses are simply regular classs -# objects in Python, however, much of the boilerplate code has been removed. For -# example, without a `@dataclass` decorator, the instantiation of the above -# `Polar` kernel would be done through -# ```python -# class Polar(jk.kernels.AbstractKernel): -# def __init__(self, period: float = 2*jnp.pi): -# super().__init__() -# self.period = period -# ``` -# As objects become increasingly large and complex, the conciseness of a -# dataclass becomes increasingly attractive. To ensure full compatability with -# Jax, it is crucial that the dataclass decorator is imported from Chex, not -# base Python's `dataclass` module. Functionally, the two objects are identical. -# However, unlike regular Python dataclasses, it is possilbe to apply operations -# such as `jit`, `vmap` and `grad` to the dataclasses given by Chex as they are -# registrered PyTrees. +# initialise to 4 in the kernel's `__init__`. 
# # # ### Custom Parameter Bijection @@ -312,7 +291,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index dae45793..c72bc565 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -95,7 +95,7 @@ natural_svgp, parameter_state=parameter_state, train_data=D, - n_iters=5000, + num_iters=5000, batch_size=256, key=jr.PRNGKey(42), moment_optim=ox.sgd(0.01), diff --git a/examples/regression.pct.py b/examples/regression.pct.py index f440e137..2b9233df 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -137,7 +137,7 @@ # # ## Parameter state # -# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` dataclass is returned that contains four dictionaries: +# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` class is returned that contains four dictionaries: # # | Dictionary | Description | # |---|---| @@ -185,11 +185,11 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=500, + num_iters=500, ) # %% [markdown] -# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a dataclass, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation. +# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a class, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation. # %% learned_params, training_history = inference_state.unpack() diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 930a50a7..9e03c650 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -171,7 +171,7 @@ parameter_state=parameter_state, train_data=D, optax_optim=optimiser, - n_iters=3000, + num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) @@ -225,7 +225,7 @@ parameter_state=parameter_state, train_data=D, optax_optim=optimiser, - n_iters=3000, + num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index a2fa97a3..d557bfae 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -136,7 +136,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, log_rate=50, ) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 331da223..d51754f5 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -28,7 +28,6 @@ from .likelihoods import Bernoulli, Gaussian from .mean_functions import Constant, Zero from .parameters import constrain, copy_dict_structure, initialise, unconstrain -from .types import Dataset from .variational_families import ( CollapsedVariationalGaussian, ExpectationVariationalGaussian, @@ -36,6 +35,7 @@ VariationalGaussian, WhitenedVariationalGaussian, ) +from .types import Dataset from .variational_inference import CollapsedVI, StochasticVI from . 
import _version diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index d9e45f77..5a3352e6 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -19,7 +19,8 @@ import jax.numpy as jnp import jax.random as jr import optax as ox -from chex import dataclass, PRNGKey as PRNGKeyType + +from jax.random import KeyArray from jax import lax from jax.experimental import host_callback from jaxtyping import Array, Float @@ -27,22 +28,40 @@ from .natural_gradients import natural_gradients from .parameters import ParameterState, constrain, trainable_params, unconstrain -from jaxutils import Dataset +from jaxutils import Dataset, PyTree from .variational_inference import StochasticVI -@dataclass(frozen=True) -class InferenceState: - """Imutable dataclass for storing optimised parameters and training history.""" +class InferenceState(PyTree): + """Imutable class for storing optimised parameters and training history.""" + + def __init__(self, params: Dict, history: Float[Array, "num_iters"]): + self._params = params + self._history = history + + @property + def params(self) -> Dict: + """Parameters. + + Returns: + Dict: Parameters. + """ + return self._params + + @property + def history(self) -> Float[Array, "num_iters"]: + """Training history. - params: Dict - history: Float[Array, "n_iters"] + Returns: + Float[Array, "num_iters"]: Training history. + """ + return self._history - def unpack(self) -> Tuple[Dict, Float[Array, "n_iters"]]: + def unpack(self) -> Tuple[Dict, Float[Array, "num_iters"]]: """Unpack parameters and training history into a tuple. Returns: - Tuple[Dict, Float[Array, "n_iters"]]: Tuple of parameters and training history. + Tuple[Dict, Float[Array, "num_iters"]]: Tuple of parameters and training history. """ return self.params, self.history @@ -51,7 +70,7 @@ def fit( objective: Callable, parameter_state: ParameterState, optax_optim: ox.GradientTransformation, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> InferenceState: @@ -62,9 +81,9 @@ def fit( objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The initial parameter state. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. 
@@ -85,7 +104,7 @@ def loss(params: Dict) -> Float[Array, "1"]: opt_state = optax_optim.init(params) # Iteration loop numbers to scan over - iter_nums = jnp.arange(n_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num: int): @@ -98,7 +117,7 @@ def step(carry, iter_num: int): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) @@ -114,9 +133,9 @@ def fit_batches( parameter_state: ParameterState, train_data: Dataset, optax_optim: ox.GradientTransformation, - key: PRNGKeyType, + key: KeyArray, batch_size: int, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> InferenceState: @@ -129,11 +148,11 @@ def fit_batches( parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. train_data (Dataset): The training dataset. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - key (PRNGKeyType): The PRNG key for the mini-batch sampling. - batch_size(int): The batch_size. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. + key (KeyArray): The PRNG key for the mini-batch sampling. + batch_size (int): The batch_size. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. @@ -154,8 +173,8 @@ def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]: opt_state = optax_optim.init(params) # Mini-batch random keys and iteration loop numbers to scan over - keys = jr.split(key, n_iters) - iter_nums = jnp.arange(n_iters) + keys = jr.split(key, num_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num__and__key): @@ -173,7 +192,7 @@ def step(carry, iter_num__and__key): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) @@ -184,7 +203,7 @@ def step(carry, iter_num__and__key): return InferenceState(params=params, history=history) -def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: +def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset: """Batch the data into mini-batches. Sampling is done with replacement. 
Args: @@ -208,9 +227,9 @@ def fit_natgrads( train_data: Dataset, moment_optim: ox.GradientTransformation, hyper_optim: ox.GradientTransformation, - key: PRNGKeyType, + key: KeyArray, batch_size: int, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> Dict: @@ -226,14 +245,14 @@ def fit_natgrads( train_data (Dataset): The training dataset. moment_optim (GradientTransformation): The Optax optimiser for the natural gradient updates on the moments. hyper_optim (GradientTransformation): The Optax optimiser for gradient updates on the hyperparameters. - key (PRNGKeyType): The PRNG key for the mini-batch sampling. + key (KeyArray): The PRNG key for the mini-batch sampling. batch_size(int): The batch_size. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: - InferenceState: A dataclass comprising optimised parameters and training history. + InferenceState: A class comprising optimised parameters and training history. """ params, trainables, bijectors = parameter_state.unpack() @@ -251,8 +270,8 @@ def fit_natgrads( ) # Mini-batch random keys and iteration loop numbers to scan over - keys = jax.random.split(key, n_iters) - iter_nums = jnp.arange(n_iters) + keys = jax.random.split(key, num_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num__and__key): @@ -276,7 +295,7 @@ def step(carry, iter_num__and__key): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _, _), history = jax.lax.scan( @@ -289,15 +308,15 @@ def step(carry, iter_num__and__key): return InferenceState(params=params, history=history) -def progress_bar_scan(n_iters: int, log_rate: int) -> Callable: +def progress_bar_scan(num_iters: int, log_rate: int) -> Callable: """Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).""" tqdm_bars = {} - remainder = n_iters % log_rate + remainder = num_iters % log_rate def _define_tqdm(args: Any, transform: Any) -> None: """Define a tqdm progress bar.""" - tqdm_bars[0] = tqdm(range(n_iters)) + tqdm_bars[0] = tqdm(range(num_iters)) def _update_tqdm(args: Any, transform: Any) -> None: """Update the tqdm progress bar with the latest objective value.""" @@ -329,10 +348,10 @@ def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None: # Conditions for iteration number is_first: bool = iter_num == 0 is_multiple: bool = (iter_num % log_rate == 0) & ( - iter_num != n_iters - remainder + iter_num != num_iters - remainder ) - is_remainder: bool = iter_num == n_iters - remainder - is_last: bool = iter_num == n_iters - 1 + is_remainder: bool = iter_num == num_iters - remainder + is_last: bool = iter_num == num_iters - 1 # Define progress bar, if first iteration _callback(is_first, _define_tqdm, None) diff --git a/gpjax/gps.py b/gpjax/gps.py index 82f0d660..a3b0fd7c 
100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -18,23 +18,23 @@ import distrax as dx import jax.numpy as jnp -from chex import dataclass, PRNGKey as PRNGKeyType from jaxtyping import Array, Float +from jax.random import KeyArray from jaxlinop import identity from jaxkern.kernels import AbstractKernel +from jaxutils import PyTree from .config import get_global_config from .kernels import AbstractKernel -from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate +from .likelihoods import AbstractLikelihood, Conjugate, NonConjugate from .mean_functions import AbstractMeanFunction, Zero from jaxutils import Dataset from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution -@dataclass -class AbstractPrior: +class AbstractPrior(PyTree): """Abstract Gaussian process prior. All Gaussian processes priors should inherit from this class. @@ -79,7 +79,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: raise NotImplementedError @abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """An initialisation method for the GP's parameters. This method should be implemented for all classes that inherit the ``AbstractPrior`` class. Whilst not always necessary, the method accepts a PRNG key to allow @@ -87,7 +87,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: through the ``initialise`` function given in GPJax. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. @@ -98,7 +98,8 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ####################### # GP Priors ####################### -@dataclass + + class Prior(AbstractPrior): """A Gaussian process prior object. The GP is parameterised by a `mean `_ @@ -120,17 +121,25 @@ class Prior(AbstractPrior): >>> >>> kernel = gpx.kernels.RBF() >>> prior = gpx.Prior(kernel = kernel) - - Attributes: - kernel (AbstractKernel): The kernel function used to parameterise the prior. - mean_function (MeanFunction): The mean function used to parameterise the - prior. Defaults to zero. - name (str): The name of the GP prior. Defaults to "GP prior". """ - kernel: AbstractKernel - mean_function: Optional[AbstractMeanFunction] = Zero() - name: Optional[str] = "GP prior" + def __init__( + self, + kernel: AbstractKernel, + mean_function: Optional[AbstractMeanFunction] = Zero(), + name: Optional[str] = "GP prior", + ) -> None: + """Initialise the GP prior. + + Args: + kernel (AbstractKernel): The kernel function used to parameterise the prior. + mean_function (Optional[MeanFunction]): The mean function used to parameterise the + prior. Defaults to zero. + name (Optional[str]): The name of the GP prior. Defaults to "GP prior". + """ + self.kernel = kernel + self.mean_function = mean_function + self.name = name def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the @@ -230,11 +239,11 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the GP prior's parameter set. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. 
@@ -248,7 +257,6 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ####################### # GP Posteriors ####################### -@dataclass class AbstractPosterior(AbstractPrior): """The base GP posterior object conditioned on an observed dataset. All posterior objects should inherit from this class. @@ -257,17 +265,24 @@ class AbstractPosterior(AbstractPrior): `_. Since dataclasses take over ``__init__``, the ``__post_init__`` method can be used to initialise the GP's parameters. - - Attributes: - prior (Prior): The prior distribution of the GP. - likelihood (AbstractLikelihood): The likelihood distribution of the - observed dataset. - name (str): The name of the GP posterior. Defaults to "GP posterior". """ - prior: Prior - likelihood: AbstractLikelihood - name: Optional[str] = "GP posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise the GP posterior object. + + Args: + prior (Prior): The prior distribution of the GP. + likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. + name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: @@ -285,11 +300,11 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """ raise NotImplementedError - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a GP posterior. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. @@ -300,7 +315,6 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ) -@dataclass class ConjugatePosterior(AbstractPosterior): """A Gaussian process posterior distribution when the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values @@ -332,16 +346,24 @@ class ConjugatePosterior(AbstractPosterior): >>> likelihood = gpx.likelihoods.Gaussian() >>> >>> posterior = prior * likelihood - - Attributes: - prior (Prior): The prior distribution of the GP. - likelihood (Gaussian): The Gaussian likelihood distribution of the observed dataset. - name (str): The name of the GP posterior. Defaults to "Conjugate posterior". """ - prior: Prior - likelihood: Gaussian - name: Optional[str] = "Conjugate posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise the conjugate GP posterior object. + + Args: + prior (Prior): The prior distribution of the GP. + likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. + name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name def predict( self, @@ -501,7 +523,7 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - negative (bool, optional): Whether or not the returned function + negative (Optional[bool]): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. 
@@ -560,7 +582,6 @@ def mll( return mll -@dataclass class NonConjugatePosterior(AbstractPosterior): """ A Gaussian process posterior object for models where the likelihood is @@ -571,24 +592,30 @@ class NonConjugatePosterior(AbstractPosterior): hyperparameters and the latent function. Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution. - - Attributes: - prior (AbstractPrior): The Gaussian process prior distribution. - likelihood (AbstractLikelihood): The likelihood function that - represents the data. - name (str): The name of the posterior object. Defaults to - "Non-conjugate posterior". """ - prior: AbstractPrior - likelihood: AbstractLikelihood - name: Optional[str] = "Non-conjugate posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise a non-conjugate Gaussian process posterior object. + + Args: + prior (AbstractPrior): The Gaussian process prior distribution. + likelihood (AbstractLikelihood): The likelihood function that represents the data. + name (Optional[str]): The name of the posterior object. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior. Args: - key (PRNGKeyType): A PRNG key used to initialise the parameters. + key (KeyArray): A PRNG key used to initialise the parameters. Returns: Dict: A dictionary containing the default parameter set. @@ -620,7 +647,7 @@ def predict( and output data used for training dataset. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a ``dx.Distribution``. """ @@ -697,7 +724,7 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - negative (bool, optional): Whether or not the returned function + negative (Optional[bool]): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. 
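Although `Prior` and the posterior classes above now define explicit `__init__` methods in place of the `@dataclass` decorators, the public workflow is unchanged. The following is a minimal regression sketch assembled from calls that appear elsewhere in this diff (construction, `initialise`, `marginal_log_likelihood`, and `fit` with the renamed `num_iters` argument); the toy data and optimiser settings are illustrative only.

```python
import jax.numpy as jnp
import jax.random as jr
import optax as ox
from jax.config import config

import gpjax as gpx

# Enable Float64 for more stable matrix inversions, as in the tests.
config.update("jax_enable_x64", True)

key = jr.PRNGKey(123)

# Toy 1D regression data (illustrative only).
x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(50, 1)), axis=0)
y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1
D = gpx.Dataset(X=x, y=y)

# Construction through the new explicit __init__ methods mirrors the old dataclass usage.
prior = gpx.Prior(kernel=gpx.RBF())
posterior = prior * gpx.Gaussian(num_datapoints=D.n)

# Initialise the parameter state and optimise the negative marginal log-likelihood.
parameter_state = gpx.initialise(posterior, key=key)
negative_mll = posterior.marginal_log_likelihood(D, negative=True)
inference_state = gpx.fit(
    objective=negative_mll,
    parameter_state=parameter_state,
    optax_optim=ox.adam(learning_rate=0.1),
    num_iters=500,
)
learned_params, training_history = inference_state.unpack()
```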
diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 7800d1b1..460c5461 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -1001,9 +1001,9 @@ def cross_covariance( return matrix -@deprecation.deprecated( - deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" -) +# @deprecation.deprecated( +# deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" +# ) class GraphKernel(AbstractKernel): def __init__( self, diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 6b39c1d2..898099f6 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -16,22 +16,28 @@ import abc from typing import Any, Callable, Dict, Optional from jaxlinop.utils import to_dense +from jaxutils import PyTree import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass from jaxtyping import Array, Float from jax.random import KeyArray -@dataclass -class AbstractLikelihood: +class AbstractLikelihood(PyTree): """Abstract base class for likelihoods.""" - num_datapoints: int # The number of datapoints that the likelihood factorises over. - name: Optional[str] = "Likelihood" + def __init__(self, num_datapoints: int, name: Optional[str] = None): + """Initialise the likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to None. + """ + self.num_datapoints = num_datapoints + self.name = name def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the likelihood function at a given predictive distribution. @@ -81,22 +87,28 @@ def link_function(self) -> Callable: raise NotImplementedError -@dataclass class Conjugate: """An abstract class for conjugate likelihoods with respect to a Gaussain process prior.""" -@dataclass class NonConjugate: """An abstract class for non-conjugate likelihoods with respect to a Gaussain process prior.""" -# TODO: revamp this will covariance operators. -@dataclass +# TODO: revamp this with covariance operators. + + class Gaussian(AbstractLikelihood, Conjugate): """Gaussian likelihood object.""" - name: Optional[str] = "Gaussian" + def __init__(self, num_datapoints: int, name: Optional[str] = "Gaussian"): + """Initialise the Gaussian likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to "Gaussian". + """ + super().__init__(num_datapoints, name) def _initialise_params(self, key: KeyArray) -> Dict: """Return the variance parameter of the likelihood function. @@ -157,9 +169,15 @@ def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distributi return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) -@dataclass class Bernoulli(AbstractLikelihood, NonConjugate): - name: Optional[str] = "Bernoulli" + def __init__(self, num_datapoints: int, name: Optional[str] = "Bernoulli"): + """Initialise the Bernoulli likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to "Bernoulli". + """ + super().__init__(num_datapoints, name) def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a Bernoulli likelihood. 
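The `Gaussian` and `Bernoulli` changes above illustrate the general migration pattern for likelihoods: field declarations on a `chex` dataclass become an explicit `__init__` that defers to the `PyTree`-based `AbstractLikelihood`. A minimal sketch of that pattern for a hypothetical downstream subclass is given below; the `MyLikelihood` name is illustrative and the abstract `predict`/`link_function` members are elided.

```python
from typing import Optional

from gpjax.likelihoods import AbstractLikelihood, NonConjugate

# Previous style (removed in this diff):
#
#     @dataclass
#     class MyLikelihood(AbstractLikelihood, NonConjugate):
#         name: Optional[str] = "MyLikelihood"


class MyLikelihood(AbstractLikelihood, NonConjugate):
    """A hypothetical non-conjugate likelihood written in the new style."""

    def __init__(self, num_datapoints: int, name: Optional[str] = "MyLikelihood"):
        """Initialise the likelihood.

        Args:
            num_datapoints (int): The number of datapoints that the likelihood factorises over.
            name (Optional[str]): The name of the likelihood. Defaults to "MyLikelihood".
        """
        super().__init__(num_datapoints, name)
```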
diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 29ec93b1..65902640 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -17,16 +17,25 @@ from typing import Dict, Optional import jax.numpy as jnp -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float +from jaxutils import PyTree -@dataclass(repr=False) -class AbstractMeanFunction: +class AbstractMeanFunction(PyTree): """Abstract mean function that is used to parameterise the Gaussian process.""" - output_dim: Optional[int] = 1 - name: Optional[str] = "Mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. Defaults to "Mean function". + """ + self.output_dim = output_dim + self.name = name @abc.abstractmethod def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: @@ -42,11 +51,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. @@ -54,14 +63,21 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: raise NotImplementedError -@dataclass(repr=False) class Zero(AbstractMeanFunction): """ A zero mean function. This function returns zero for all inputs. """ - output_dim: Optional[int] = 1 - name: Optional[str] = "Zero mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the zero-mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. Defaults to "Mean function". + """ + super().__init__(output_dim, name) def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. @@ -76,11 +92,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the zero-mean function, this is an empty dictionary. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. @@ -88,15 +104,22 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return {} -@dataclass(repr=False) class Constant(AbstractMeanFunction): """ A zero mean function. This function returns a repeated scalar value for all inputs. The scalar value itself can be treated as a model hyperparameter and learned during training. 
""" - output_dim: Optional[int] = 1 - name: Optional[str] = "Constant mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the constant-mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. Defaults to "Mean function". + """ + super().__init__(output_dim, name) def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. @@ -111,11 +134,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 85b01f19..f409905d 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -22,8 +22,9 @@ import jax import jax.numpy as jnp import jax.random as jr -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float +from jaxutils import PyTree from .config import Identity, get_global_config from .utils import merge_dictionaries @@ -32,17 +33,17 @@ ################################ # Base operations ################################ -@dataclass -class ParameterState: +class ParameterState(PyTree): """ The state of the model. This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained. """ - params: Dict - trainables: Dict - bijectors: Dict + def __init__(self, params: Dict, trainables: Dict, bijectors: Dict) -> None: + self.params = params + self.trainables = trainables + self.bijectors = bijectors def unpack(self): """Unpack the state into a tuple of parameters, trainables and bijectors. @@ -53,7 +54,7 @@ def unpack(self): return self.params, self.trainables, self.bijectors -def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: +def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: """ Initialise the stateful parameters of any GPJax object. This function also returns the trainability status of each parameter and set of bijectors that @@ -61,7 +62,7 @@ def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: Args: model: The GPJax object that is to be initialised. - key (PRNGKeyType, optional): The random key that is to be used for + key (KeyArray, optional): The random key that is to be used for initialisation. Defaults to None. Returns: diff --git a/gpjax/test_variational_inference.py b/gpjax/test_variational_inference.py new file mode 100644 index 00000000..1e7eb9eb --- /dev/null +++ b/gpjax/test_variational_inference.py @@ -0,0 +1,159 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import typing as tp + +import jax +import jax.numpy as jnp +import jax.random as jr +import pytest +from jax.config import config + +import gpjax as gpx +from gpjax.variational_families import ( + CollapsedVariationalGaussian, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, + VariationalGaussian, + WhitenedVariationalGaussian, +) + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) + + +def test_abstract_variational_inference(): + prior = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=20) + post = prior * lik + n_inducing_points = 10 + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + vartiational_family = gpx.VariationalGaussian( + prior=prior, inducing_inputs=inducing_inputs + ) + + with pytest.raises(TypeError): + gpx.variational_inference.AbstractVariationalInference( + posterior=post, vartiational_family=vartiational_family + ) + + +def get_data_and_gp(n_datapoints, point_dim): + x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) + y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 + x = jnp.hstack([x] * point_dim) + D = gpx.Dataset(X=x, y=y) + + p = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=n_datapoints) + post = p * lik + return D, post, p + + +@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) +@pytest.mark.parametrize("jit_fns", [False, True]) +@pytest.mark.parametrize("point_dim", [1, 2, 3]) +@pytest.mark.parametrize( + "variational_family", + [ + VariationalGaussian, + WhitenedVariationalGaussian, + NaturalVariationalGaussian, + ExpectationVariationalGaussian, + ], +) +def test_stochastic_vi( + n_datapoints, n_inducing_points, jit_fns, point_dim, variational_family +): + D, post, prior = get_data_and_gp(n_datapoints, point_dim) + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) + + q = variational_family(prior=prior, inducing_inputs=inducing_inputs) + + svgp = gpx.StochasticVI(posterior=post, variational_family=q) + assert svgp.posterior.prior == post.prior + assert svgp.posterior.likelihood == post.likelihood + + params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() + + assert svgp.prior == post.prior + assert svgp.likelihood == post.likelihood + + if jit_fns: + elbo_fn = jax.jit(svgp.elbo(D)) + else: + elbo_fn = svgp.elbo(D) + assert isinstance(elbo_fn, tp.Callable) + elbo_value = elbo_fn(params, D) + assert isinstance(elbo_value, jnp.ndarray) + + # Test gradients + grads = jax.grad(elbo_fn, argnums=0)(params, D) + assert isinstance(grads, tp.Dict) + assert len(grads) == len(params) + + +@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) +@pytest.mark.parametrize("jit_fns", [False, True]) +@pytest.mark.parametrize("point_dim", [1, 2]) +def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): + D, post, prior = get_data_and_gp(n_datapoints, point_dim) + likelihood = 
gpx.Gaussian(num_datapoints=n_datapoints) + + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) + + q = CollapsedVariationalGaussian( + prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs + ) + + sgpr = gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) + assert sgpr.posterior.prior == post.prior + assert sgpr.posterior.likelihood == post.likelihood + + params, _, _ = gpx.initialise(sgpr, jr.PRNGKey(123)).unpack() + + assert sgpr.prior == post.prior + assert sgpr.likelihood == post.likelihood + + if jit_fns: + elbo_fn = jax.jit(sgpr.elbo(D)) + else: + elbo_fn = sgpr.elbo(D) + assert isinstance(elbo_fn, tp.Callable) + elbo_value = elbo_fn(params) + assert isinstance(elbo_value, jnp.ndarray) + + # Test gradients + grads = jax.grad(elbo_fn)(params) + assert isinstance(grads, tp.Dict) + assert len(grads) == len(params) + + # We should raise an error for non-Collapsed variational families: + with pytest.raises(TypeError): + q = gpx.variational_families.VariationalGaussian( + prior=prior, inducing_inputs=inducing_inputs + ) + gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) + + # We should raise an error for non-Gaussian likelihoods: + with pytest.raises(TypeError): + q = gpx.variational_families.CollapsedVariationalGaussian( + prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs + ) + gpx.variational_inference.CollapsedVI( + posterior=prior * gpx.Bernoulli(num_datapoints=D.n), variational_family=q + ) diff --git a/gpjax/types.py b/gpjax/types.py index b3fedb97..d1e8b110 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -13,67 +13,20 @@ # limitations under the License. # ============================================================================== -import jax.numpy as jnp -from chex import dataclass -from jaxtyping import Array, Float +import jaxutils import deprecation -NoneType = type(None) - - -@deprecation.deprecated( +Dataset = deprecation.deprecated( deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxUtils for a Dataset object", -) -@dataclass -class Dataset: - """GPJax Dataset class.""" - - X: Float[Array, "N D"] - y: Float[Array, "N 1"] = None - - def __repr__(self) -> str: - return ( - f"- Number of datapoints: {self.X.shape[0]}\n- Dimension:" - f" {self.X.shape[1]}" - ) - - def __add__(self, other: "Dataset") -> "Dataset": - """Combines two datasets into one. The right-hand dataset is stacked beneath left.""" - x = jnp.concatenate((self.X, other.X)) - y = jnp.concatenate((self.y, other.y)) - - return Dataset(X=x, y=y) +)(jaxutils.Dataset) - @property - def n(self) -> int: - """The number of observations in the dataset.""" - return self.X.shape[0] - - @property - def in_dim(self) -> int: - """The dimension of the input data.""" - return self.X.shape[1] - - @property - def out_dim(self) -> int: - """The dimension of the output data.""" - return self.y.shape[1] +verify_dataset = deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxUtils for a Dataset object", +)(jaxutils.verify_dataset) -def verify_dataset(ds: Dataset) -> None: - """Apply a series of checks to the dataset to ensure that downstream operations are safe.""" - assert ds.X.ndim == 2, ( - "2-dimensional training inputs are required. Current dimension:" - f" {ds.X.ndim}." - ) - if ds.y is not None: - assert ds.y.ndim == 2, ( - "2-dimensional training outputs are required. Current dimension:" - f" {ds.y.ndim}." 
- ) - assert ds.X.shape[0] == ds.y.shape[0], ( - "Number of inputs must equal the number of outputs. \nCurrent" - f" counts:\n- X: {ds.X.shape[0]}\n- y: {ds.y.shape[0]}" - ) +__all__ = ["Dataset" "verify_dataset"] diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e9c6ecd0..ac7e6ccc 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -19,11 +19,11 @@ import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float from jaxlinop import identity -from jaxutils import Dataset +from jaxutils import PyTree, Dataset import jaxlinop as jlo from .config import get_global_config @@ -33,8 +33,7 @@ from .gaussian_distribution import GaussianDistribution -@dataclass -class AbstractVariationalFamily: +class AbstractVariationalFamily(PyTree): """ Abstract base class used to represent families of distributions that can be used within variational inference. @@ -55,13 +54,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: return self.predict(*args, **kwargs) @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """ The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix. Args: - key (PRNGKeyType): The PRNG key used to initialise the parameters. + key (KeyArray): The PRNG key used to initialise the parameters. Returns: Dict: The parameters of the distribution. @@ -84,20 +83,27 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: raise NotImplementedError -@dataclass class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" - prior: Prior - inducing_inputs: Float[Array, "N D"] - name: str = "Gaussian" - - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Variational Gaussian", + ) -> None: + """ + Args: + prior (Prior): The prior distribution. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. Defaults to "Gaussian". + """ + self.prior = prior + self.inducing_inputs = inducing_inputs self.num_inducing = self.inducing_inputs.shape[0] + self.name = name -@dataclass class VariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions. @@ -108,14 +114,14 @@ class VariationalGaussian(AbstractVariationalGaussian): :math:`\\mu` and sqrt with S = sqrt sqrtᵀ. """ - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """ Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution. Args: - key (PRNGKeyType): The PRNG key used to initialise the parameters. + key (KeyArray): The PRNG key used to initialise the parameters. Returns: Dict: The parameters of the distribution. @@ -250,7 +256,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class WhitenedVariationalGaussian(VariationalGaussian): """ The whitened variational Gaussian family of probability distributions. 
@@ -262,7 +267,21 @@ class WhitenedVariationalGaussian(VariationalGaussian): """ - name: str = "Whitened variational Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Whitened variational Gaussian", + ) -> None: + """Initialise the whitened variational Gaussian family. + + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) def prior_kl(self, params: Dict) -> Float[Array, "1"]: """Compute the KL-divergence between our variational approximation and @@ -355,7 +374,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -363,12 +381,25 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian): and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform model inference, where T(u) = [u, uuᵀ] are the sufficient statistics. - """ - name: str = "Natural Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Natural variational Gaussian", + ) -> None: + """Initialise the natural variational Gaussian family. - def _initialise_params(self, key: PRNGKeyType) -> Dict: + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) + + def _initialise_params(self, key: KeyArray) -> Dict: """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" m = self.num_inducing @@ -527,7 +558,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: return predict_fn -@dataclass class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -538,9 +568,23 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian): η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over. """ - name: str = "Expectation Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Expectation variational Gaussian", + ) -> None: + """Initialise the expectation variational Gaussian family. + + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" self.num_inducing = self.inducing_inputs.shape[0] @@ -691,25 +735,36 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class CollapsedVariationalGaussian(AbstractVariationalFamily): """Collapsed variational Gaussian family of probability distributions. 
The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - prior: Prior - likelihood: AbstractLikelihood - inducing_inputs: Float[Array, "M D"] - name: str = "Collapsed variational Gaussian" - diag: Optional[bool] = False + def __init__( + self, + prior: Prior, + likelihood: AbstractLikelihood, + inducing_inputs: Float[Array, "M D"], + name: str = "Collapsed variational Gaussian", + ): + """Initialise the collapsed variational Gaussian family of probability distributions. - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" - self.num_inducing = self.inducing_inputs.shape[0] + Args: + prior (Prior): The prior distribution that we are approximating. + likelihood (AbstractLikelihood): The likelihood function that we are using to model the data. + inducing_inputs (Float[Array, "M D"]): The inducing inputs that are to be used to parameterise the variational Gaussian distribution. + name (str, optional): The name of the variational family. Defaults to "Collapsed variational Gaussian". + """ - if not isinstance(self.likelihood, Gaussian): + if not isinstance(likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - def _initialise_params(self, key: PRNGKeyType) -> Dict: + self.prior = prior + self.likelihood = likelihood + self.inducing_inputs = inducing_inputs + self.num_inducing = self.inducing_inputs.shape[0] + self.name = name + + def _initialise_params(self, key: KeyArray) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 34f493d1..6745656e 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -18,11 +18,12 @@ import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass, PRNGKey as PRNGKeyType from jax import vmap from jaxtyping import Array, Float from jaxlinop import identity +from jax.random import KeyArray +from jaxutils import PyTree from .config import get_global_config from .gps import AbstractPosterior @@ -36,18 +37,26 @@ ) -@dataclass -class AbstractVariationalInference: +class AbstractVariationalInference(PyTree): """A base class for inference and training of variational families against an extact posterior""" - posterior: AbstractPosterior - variational_family: AbstractVariationalFamily + def __init__( + self, + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + ) -> None: + """Initialise the variational inference module. - def __post_init__(self): + Args: + posterior (AbstractPosterior): The exact posterior distribution. + variational_family (AbstractVariationalFamily): The variational family to be trained. + """ + self.posterior = posterior self.prior = self.posterior.prior self.likelihood = self.posterior.likelihood + self.variational_family = variational_family - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Construct the parameter set used within the variational scheme adopted.""" hyperparams = concat_dictionaries( {"likelihood": self.posterior.likelihood._initialise_params(key)}, @@ -71,15 +80,9 @@ def elbo( raise NotImplementedError -@dataclass class StochasticVI(AbstractVariationalInference): """Stochastic Variational inference training module. 
The key reference is Hensman et. al., (2013) - Gaussian processes for big data.""" - def __post_init__(self): - self.prior = self.posterior.prior - self.likelihood = self.posterior.likelihood - self.num_inducing = self.variational_family.num_inducing - def elbo( self, train_data: Dataset, negative: bool = False ) -> Callable[[Float[Array, "N D"]], Float[Array, "1"]]: @@ -144,22 +147,30 @@ def q_moments(x): return expectation -@dataclass class CollapsedVI(AbstractVariationalInference): """Collapsed variational inference for a sparse Gaussian process regression model. The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - def __post_init__(self): - self.prior = self.posterior.prior - self.likelihood = self.posterior.likelihood - self.num_inducing = self.variational_family.num_inducing + def __init__( + self, + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + ) -> None: + """Initialise the variational inference module. + + Args: + posterior (AbstractPosterior): The exact posterior distribution. + variational_family (AbstractVariationalFamily): The variational family to be trained. + """ - if not isinstance(self.likelihood, Gaussian): + if not isinstance(posterior.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - if not isinstance(self.variational_family, CollapsedVariationalGaussian): + if not isinstance(variational_family, CollapsedVariationalGaussian): raise TypeError("Variational family must be CollapsedVariationalGaussian.") + super().__init__(posterior, variational_family) + def elbo( self, train_data: Dataset, negative: bool = False ) -> Callable[[Dict], Float[Array, "1"]]: @@ -180,7 +191,7 @@ def elbo( mean_function = self.prior.mean_function kernel = self.prior.kernel - m = self.num_inducing + m = self.variational_family.num_inducing jitter = get_global_config()["jitter"] # Constant for whether or not to negate the elbo for optimisation purposes diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 00000000..b6081cc7 --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,7 @@ +black +isort +pylint +flake8 +pytest +networkx +pytest-cov \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 00000000..a8b062aa --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,11 @@ +jax>=0.4.1 +jaxlib>=0.4.1 +optax +jaxutils +jaxkern +distrax>=0.1.2 +tqdm>=4.0.0 +ml-collections==0.1.0 +jaxtyping>=0.0.2 +jaxlinop>=0.0.3 +deprecation \ No newline at end of file diff --git a/setup.py b/setup.py index 53c5d678..d98ddd57 100644 --- a/setup.py +++ b/setup.py @@ -22,21 +22,6 @@ def get_versions(): return versions -REQUIRES = [ - "jax>=0.4.1", - "jaxlib>=0.4.1", - "optax", - "jaxutils", - "jaxkern", - "chex", - "distrax>=0.1.2", - "tqdm>=4.0.0", - "ml-collections==0.1.0", - "jaxtyping>=0.0.2", - "jaxlinop>=0.0.3", - "deprecation", -] - EXTRAS = { "dev": [ "black", @@ -66,7 +51,20 @@ def get_versions(): "Documentation": "https://gpjax.readthedocs.io/en/latest/", "Source": "https://github.com/thomaspinder/GPJax", }, - install_requires=REQUIRES, + python_requires=">=3.7", + install_requires=[ + "jax>=0.4.1", + "jaxlib>=0.4.1", + "optax", + "jaxutils>=0.0.6", + "jaxkern", + "distrax>=0.1.2", + "tqdm>=4.0.0", + "ml-collections==0.1.0", + "jaxtyping>=0.0.2", + "jaxlinop>=0.0.3", + "deprecation", + ], tests_require=EXTRAS["dev"], extras_require=EXTRAS, keywords=["gaussian-processes jax 
machine-learning bayesian"], diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index c9ba5c38..4b7fce30 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -28,10 +28,10 @@ config.update("jax_enable_x64", True) -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("n", [1, 20]) @pytest.mark.parametrize("verbose", [True, False]) -def test_fit(n_iters, n, verbose): +def test_fit(num_iters, n, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -41,13 +41,13 @@ def test_fit(n_iters, n, verbose): mll = p.marginal_log_likelihood(D, negative=True) pre_mll_val = mll(parameter_state.params) optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(mll, parameter_state, optimiser, n_iters, verbose=verbose) + inference_state = fit(mll, parameter_state, optimiser, num_iters, verbose=verbose) optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert mll(optimised_params) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters def test_stop_grads(): @@ -59,18 +59,18 @@ def test_stop_grads(): parameter_state = ParameterState( params=params, trainables=trainables, bijectors=bijectors ) - inference_state = fit(loss_fn, parameter_state, optimiser, n_iters=1) + inference_state = fit(loss_fn, parameter_state, optimiser, num_iters=1) learned_params = inference_state.params assert isinstance(inference_state, InferenceState) assert learned_params["y"] == params["y"] assert learned_params["x"] != params["x"] -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("nb", [1, 20, 50]) @pytest.mark.parametrize("ndata", [50]) @pytest.mark.parametrize("verbose", [True, False]) -def test_batch_fitting(n_iters, nb, ndata, verbose): +def test_batch_fitting(num_iters, nb, ndata, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -93,21 +93,21 @@ def test_batch_fitting(n_iters, nb, ndata, verbose): optimiser = optax.adam(learning_rate=0.1) key = jr.PRNGKey(42) inference_state = fit_batches( - objective, parameter_state, D, optimiser, key, nb, n_iters, verbose=verbose + objective, parameter_state, D, optimiser, key, nb, num_iters, verbose=verbose ) optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert objective(optimised_params, D) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("nb", [1, 20, 50]) @pytest.mark.parametrize("ndata", [50]) @pytest.mark.parametrize("verbose", [True, False]) -def test_natural_gradients(ndata, nb, n_iters, verbose): +def test_natural_gradients(ndata, nb, num_iters, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -140,7 +140,7 @@ def test_natural_gradients(ndata, nb, n_iters, verbose): 
hyper_optimiser, key, nb, - n_iters, + num_iters, verbose=verbose, ) optimised_params, history = inference_state.unpack() @@ -148,7 +148,7 @@ def test_natural_gradients(ndata, nb, n_iters, verbose): assert isinstance(optimised_params, dict) assert objective(optimised_params, D) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters @pytest.mark.parametrize("batch_size", [1, 2, 50]) diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py index 8f3c0b63..46a67fb1 100644 --- a/tests/test_gaussian_distribution.py +++ b/tests/test_gaussian_distribution.py @@ -72,6 +72,9 @@ def test_diag_linear_operator(n: int) -> None: distrax_dist = MultivariateNormalDiag(loc=mean, scale_diag=diag) assert approx_equal(dist_diag.mean(), distrax_dist.mean()) + assert approx_equal(dist_diag.mode(), distrax_dist.mode()) + assert approx_equal(dist_diag.median(), distrax_dist.median()) + assert approx_equal(dist_diag.entropy(), distrax_dist.entropy()) assert approx_equal(dist_diag.variance(), distrax_dist.variance()) assert approx_equal(dist_diag.stddev(), distrax_dist.stddev()) assert approx_equal(dist_diag.covariance(), distrax_dist.covariance()) @@ -104,6 +107,9 @@ def test_dense_linear_operator(n: int) -> None: ) assert approx_equal(dist_dense.mean(), distrax_dist.mean()) + assert approx_equal(dist_dense.mode(), distrax_dist.mode()) + assert approx_equal(dist_dense.median(), distrax_dist.median()) + assert approx_equal(dist_dense.entropy(), distrax_dist.entropy()) assert approx_equal(dist_dense.variance(), distrax_dist.variance()) assert approx_equal(dist_dense.stddev(), distrax_dist.stddev()) assert approx_equal(dist_dense.covariance(), distrax_dist.covariance()) @@ -142,3 +148,7 @@ def test_kl_divergence(n: int) -> None: assert approx_equal( dist_a.kl_divergence(dist_b), distrax_dist_a.kl_divergence(distrax_dist_b) ) + + with pytest.raises(ValueError): + incompatible = GaussianDistribution(loc=jnp.ones((2 * n,))) + incompatible.kl_divergence(dist_a) diff --git a/tests/test_gps.py b/tests/test_gps.py index d7a4246d..9bab15d8 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -25,6 +25,7 @@ from gpjax import Dataset, initialise from gpjax.gps import ( AbstractPrior, + AbstractPosterior, ConjugatePosterior, NonConjugatePosterior, Prior, @@ -166,6 +167,24 @@ def test_param_construction(num_datapoints, lik): ] +@pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) +def test_abstract_posterior(lik): + pr = Prior(kernel=RBF()) + likelihood = lik(num_datapoints=10) + + with pytest.raises(TypeError): + _ = AbstractPosterior(pr, likelihood) + + class DummyPosterior(AbstractPosterior): + def predict(self): + pass + + dummy_post = DummyPosterior(pr, likelihood) + assert isinstance(dummy_post, AbstractPosterior) + assert dummy_post.likelihood == likelihood + assert dummy_post.prior == pr + + @pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) def test_posterior_construct(lik): pr = Prior(kernel=RBF()) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index bb7bafe4..df4c8aa0 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -22,33 +22,29 @@ import jax.random as jr import networkx as nx import pytest +from gpjax.parameters import initialise from jax.config import config +from jax.random import KeyArray as PRNGKeyType +from jaxlinop import LinearOperator, identity from jaxtyping import Array, Float -from chex import PRNGKey as PRNGKeyType - -from jaxlinop import ( - LinearOperator, - 
identity, -) from gpjax.kernels import ( RBF, - Linear, - RationalQuadratic, + AbstractKernel, CombinationKernel, GraphKernel, - AbstractKernel, + Linear, Matern12, Matern32, Matern52, + Periodic, Polynomial, PoweredExponential, ProductKernel, - Periodic, + RationalQuadratic, SumKernel, euclidean_distance, ) -from gpjax.parameters import initialise # Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) @@ -559,6 +555,7 @@ def test_graph_kernel(): # Create graph kernel kern = GraphKernel(laplacian=L) + assert isinstance(kern, GraphKernel) assert kern.num_vertex == n_verticies assert kern.evals.shape == (n_verticies, 1) assert kern.evecs.shape == (n_verticies, n_verticies) diff --git a/tests/test_types.py b/tests/test_types.py index d4b11a94..f08f66bf 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,4 +1,4 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,32 +15,26 @@ import jax.numpy as jnp import pytest -from jax.config import config - -from gpjax.types import Dataset, NoneType, verify_dataset - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_nonetype(): - assert isinstance(None, NoneType) +from gpjax.types import Dataset, verify_dataset @pytest.mark.parametrize("n", [1, 10]) @pytest.mark.parametrize("outd", [1, 2, 10]) @pytest.mark.parametrize("ind", [1, 2, 10]) @pytest.mark.parametrize("n2", [1, 10]) -def test_dataset(n, outd, ind, n2): +def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: x = jnp.ones((n, ind)) y = jnp.ones((n, outd)) d = Dataset(X=x, y=y) + verify_dataset(d) assert d.n == n assert d.in_dim == ind assert d.out_dim == outd - # test combine datasets + assert d.__repr__() == f"- Number of datapoints: {n}\n- Dimension: {ind}" + + # Test combine datasets. x2 = 2 * jnp.ones((n2, ind)) y2 = 2 * jnp.ones((n2, outd)) d2 = Dataset(X=x2, y=y2) @@ -54,17 +48,41 @@ def test_dataset(n, outd, ind, n2): assert (d_combined.X[:n] == 1.0).all() assert (d_combined.X[n:] == 2.0).all() + # Test supervised and unsupervised. 
+ assert d.is_supervised() is True + dunsup = Dataset(y=y) + assert dunsup.is_unsupervised() is True + @pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) -def test_dataset_assertions(nx, ny): - x = jnp.ones((nx, 1)) - y = jnp.ones((ny, 1)) - with pytest.raises(AssertionError): +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: + x = jnp.ones((nx, ind)) + y = jnp.ones((ny, outd)) + + with pytest.raises(ValueError): + ds = Dataset(X=x, y=y) + + +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_2d_inputs(n: int, outd: int, ind: int) -> None: + x = jnp.ones((n, ind)) + y = jnp.ones((n,)) + + with pytest.raises(ValueError): + ds = Dataset(X=x, y=y) + + x = jnp.ones((n,)) + y = jnp.ones((n, outd)) + + with pytest.raises(ValueError): ds = Dataset(X=x, y=y) - verify_dataset(ds) -def test_y_none(): +def test_y_none() -> None: x = jnp.ones((10, 1)) d = Dataset(X=x) verify_dataset(d) diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py index e310f275..1e7eb9eb 100644 --- a/tests/test_variational_inference.py +++ b/tests/test_variational_inference.py @@ -91,7 +91,6 @@ def test_stochastic_vi( assert svgp.prior == post.prior assert svgp.likelihood == post.likelihood - assert svgp.num_inducing == n_inducing_points if jit_fns: elbo_fn = jax.jit(svgp.elbo(D)) @@ -129,7 +128,6 @@ def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): assert sgpr.prior == post.prior assert sgpr.likelihood == post.likelihood - assert sgpr.num_inducing == n_inducing_points if jit_fns: elbo_fn = jax.jit(sgpr.elbo(D))
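
The hunks above replace the `@dataclass`-based variational families with explicit `__init__` methods and swap the `chex.PRNGKey` alias for `jax.random.KeyArray`. Below is a minimal, untested sketch of how a family is constructed after this change. The constructor and `_initialise_params` signatures are taken directly from the diff; the `gpjax.variational_families` module path and the `Prior`/`RBF` imports are assumptions based on the test files touched in this patch.

```python
import jax.numpy as jnp
import jax.random as jr

from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.variational_families import NaturalVariationalGaussian

key = jr.PRNGKey(123)

# Twenty inducing inputs on a one-dimensional grid.
z = jnp.linspace(-2.0, 2.0, 20).reshape(-1, 1)

# A GP prior; the kernel import mirrors the test files in this diff.
prior = Prior(kernel=RBF())

# The family is now built through an explicit __init__ call rather than dataclass fields.
q = NaturalVariationalGaussian(prior=prior, inducing_inputs=z)

# _initialise_params now takes a jax.random.KeyArray (the chex PRNGKeyType alias was dropped).
params = q._initialise_params(key)
print(list(params.keys()))
```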
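
A second sketch covers the collapsed family and its inference module under the new constructors. The `CollapsedVariationalGaussian` and `CollapsedVI` signatures, the Gaussian-likelihood check, the `elbo(train_data, negative=...)` callable, and the `variational_family.num_inducing` access are all taken from the hunks above; the `gpjax.likelihoods` import path and the `prior * likelihood` posterior construction are assumptions based on the wider GPJax API.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

from gpjax import Dataset
from gpjax.gps import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Gaussian  # assumed import path
from gpjax.variational_families import CollapsedVariationalGaussian
from gpjax.variational_inference import CollapsedVI

key = jr.PRNGKey(123)

# A small regression dataset, mirroring the test fixtures in this diff.
x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(50, 1)), axis=0)
y = jnp.sin(x) + 0.1 * jr.normal(key=key, shape=x.shape)
D = Dataset(X=x, y=y)
z = jnp.linspace(-2.0, 2.0, 20).reshape(-1, 1)

prior = Prior(kernel=RBF())
likelihood = Gaussian(num_datapoints=D.n)
posterior = prior * likelihood  # assumed posterior construction

# The collapsed family still requires a Gaussian likelihood; the check now lives in __init__.
q = CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)

# CollapsedVI validates its arguments before delegating to the parent __init__.
sgpr = CollapsedVI(posterior=posterior, variational_family=q)

# elbo() returns a callable over the parameter dictionary, and the inducing-point
# count is read from the variational family rather than stored on the module.
params = sgpr._initialise_params(key)
negative_elbo = jax.jit(sgpr.elbo(D, negative=True))
print(sgpr.variational_family.num_inducing, negative_elbo(params))
```

One consequence of moving validation into `__init__` is that misuse, such as passing a non-Gaussian likelihood, now fails at construction time rather than inside `__post_init__`.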