From 45519df63c02178375e4d45d98c76c7ad7a189d8 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 8 Jan 2023 15:38:09 +0000 Subject: [PATCH 1/5] ` init_params` revamp, remove test from `./gpjax` --- docs/_api.rst | 2 +- examples/graph_kernels.pct.py | 4 +- examples/haiku.pct.py | 4 +- examples/kernels.pct.py | 20 ++-- gpjax/gps.py | 46 ++++---- gpjax/kernels.py | 42 +++++--- gpjax/likelihoods.py | 17 ++- gpjax/mean_functions.py | 17 ++- gpjax/parameters.py | 2 +- gpjax/test_variational_inference.py | 159 ---------------------------- gpjax/variational_families.py | 31 ++++-- gpjax/variational_inference.py | 17 ++- tests/test_kernels.py | 34 +++--- tests/test_likelihoods.py | 10 +- tests/test_mean_functions.py | 6 +- tests/test_variational_families.py | 10 +- 16 files changed, 157 insertions(+), 264 deletions(-) delete mode 100644 gpjax/test_variational_inference.py diff --git a/docs/_api.rst b/docs/_api.rst index a26db0df..b942a7ff 100644 --- a/docs/_api.rst +++ b/docs/_api.rst @@ -30,7 +30,7 @@ process objects. .. autoclass:: AbstractPrior :members: :special-members: __call__ - :private-members: _initialise_params + :private-members: init_params :exclude-members: from_tuple, replace, to_tuple .. autoclass:: AbstractPosterior diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index 6cb52c65..769e25ae 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -85,7 +85,7 @@ kernel = jk.GraphKernel(laplacian=L) prior = gpx.Prior(kernel=kernel) -true_params = prior._initialise_params(key) +true_params = prior.init_params(key) true_params["kernel"] = { "lengthscale": jnp.array(2.3), "variance": jnp.array(3.2), @@ -101,7 +101,7 @@ kernel.compute_engine.gram # %% -kernel.gram(params=kernel._initialise_params(key), inputs=x) +kernel.gram(params=kernel.init_params(key), inputs=x) # %% [markdown] # diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index e98978f5..fc634e47 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -107,10 +107,10 @@ def __call__( def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None: nn_params = self.network.init(rng=key, x=dummy_x) - base_kernel_params = self.base_kernel._initialise_params(key) + base_kernel_params = self.base_kernel.init_params(key) self._params = {**nn_params, **base_kernel_params} - def _initialise_params(self, key: jr.KeyArray) -> Dict: + def init_params(self, key: jr.KeyArray) -> Dict: return self._params diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index c4408c6d..a9869daf 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -97,7 +97,7 @@ # %% print(f"ARD: {slice_kernel.ard}") -print(f"Lengthscales: {slice_kernel._initialise_params(key)['lengthscale']}") +print(f"Lengthscales: {slice_kernel.init_params(key)['lengthscale']}") # %% [markdown] # We'll now simulate some data and evaluate the kernel on the previously selected input dimensions. @@ -107,7 +107,7 @@ x_matrix = jr.normal(key, shape=(50, 5)) # Default parameter dictionary -params = slice_kernel._initialise_params(key) +params = slice_kernel.init_params(key) # Compute the Gram matrix K = slice_kernel.gram(params, x_matrix) @@ -127,9 +127,9 @@ sum_k = k1 + k2 fig, ax = plt.subplots(ncols=3, figsize=(20, 5)) -im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense()) -im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense()) -im2 = ax[2].matshow(sum_k.gram(sum_k._initialise_params(key), x).to_dense()) +im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense()) +im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense()) +im2 = ax[2].matshow(sum_k.gram(sum_k.init_params(key), x).to_dense()) fig.colorbar(im0, ax=ax[0]) fig.colorbar(im1, ax=ax[1]) @@ -144,10 +144,10 @@ prod_k = k1 * k2 * k3 fig, ax = plt.subplots(ncols=4, figsize=(20, 5)) -im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense()) -im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense()) -im2 = ax[2].matshow(k3.gram(k3._initialise_params(key), x).to_dense()) -im3 = ax[3].matshow(prod_k.gram(prod_k._initialise_params(key), x).to_dense()) +im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense()) +im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense()) +im2 = ax[2].matshow(k3.gram(k3.init_params(key), x).to_dense()) +im3 = ax[3].matshow(prod_k.gram(prod_k.init_params(key), x).to_dense()) fig.colorbar(im0, ax=ax[0]) fig.colorbar(im1, ax=ax[1]) @@ -218,7 +218,7 @@ def __call__( K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau return K.squeeze() - def _initialise_params(self, key: jr.PRNGKey) -> dict: + def init_params(self, key: jr.PRNGKey) -> dict: return {"tau": jnp.array([4.0])} diff --git a/gpjax/gps.py b/gpjax/gps.py index a3b0fd7c..e38ad954 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -33,16 +33,13 @@ from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution +import deprecation + class AbstractPrior(PyTree): """Abstract Gaussian process prior. - All Gaussian processes priors should inherit from this class. - - All GPJax Modules are `Chex dataclasses `_. Since - dataclasses take over ``__init__``, the ``__post_init__`` method can be used to - initialise the GP's parameters. - """ + All Gaussian processes priors should inherit from this class.""" def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the Gaussian process at the given points. The output of this function @@ -79,7 +76,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: raise NotImplementedError @abstractmethod - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """An initialisation method for the GP's parameters. This method should be implemented for all classes that inherit the ``AbstractPrior`` class. Whilst not always necessary, the method accepts a PRNG key to allow @@ -94,6 +91,15 @@ def _initialise_params(self, key: KeyArray) -> Dict: """ raise NotImplementedError + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + ####################### # GP Priors @@ -239,7 +245,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Initialise the GP prior's parameter set. Args: @@ -249,8 +255,8 @@ def _initialise_params(self, key: KeyArray) -> Dict: Dict: The initialised parameter set. """ return { - "kernel": self.kernel._initialise_params(key), - "mean_function": self.mean_function._initialise_params(key), + "kernel": self.kernel.init_params(key), + "mean_function": self.mean_function.init_params(key), } @@ -259,13 +265,7 @@ def _initialise_params(self, key: KeyArray) -> Dict: ####################### class AbstractPosterior(AbstractPrior): """The base GP posterior object conditioned on an observed dataset. All - posterior objects should inherit from this class. - - All GPJax Modules are `Chex dataclasses - `_. Since dataclasses - take over ``__init__``, the ``__post_init__`` method can be used to - initialise the GP's parameters. - """ + posterior objects should inherit from this class.""" def __init__( self, @@ -300,7 +300,7 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """ raise NotImplementedError - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a GP posterior. Args: @@ -310,8 +310,8 @@ def _initialise_params(self, key: KeyArray) -> Dict: Dict: The initialised parameter set. """ return concat_dictionaries( - self.prior._initialise_params(key), - {"likelihood": self.likelihood._initialise_params(key)}, + self.prior.init_params(key), + {"likelihood": self.likelihood.init_params(key)}, ) @@ -611,7 +611,7 @@ def __init__( self.likelihood = likelihood self.name = name - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior. Args: @@ -621,8 +621,8 @@ def _initialise_params(self, key: KeyArray) -> Dict: Dict: A dictionary containing the default parameter set. """ parameters = concat_dictionaries( - self.prior._initialise_params(key), - {"likelihood": self.likelihood._initialise_params(key)}, + self.prior.init_params(key), + {"likelihood": self.likelihood.init_params(key)}, ) parameters["latent"] = jnp.zeros(shape=(self.likelihood.num_datapoints, 1)) return parameters diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 460c5461..7b0b9759 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -27,8 +27,7 @@ from jax import vmap import jax from jaxtyping import Array, Float - -from chex import PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxutils import PyTree import deprecation @@ -352,11 +351,11 @@ def ard(self): return True if self.ndims > 1 else False @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """A template dictionary of the kernel's parameter set. Args: - key (PRNGKeyType): A PRNG key to be used for initialising + key (KeyArray): A PRNG key to be used for initialising the kernel's parameters. Returns: @@ -364,6 +363,15 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: """ raise NotImplementedError + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + class CombinationKernel(AbstractKernel): """A base class for products or sums of kernels.""" @@ -399,9 +407,9 @@ def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: self.kernel_set = kernels_list - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """A template dictionary of the kernel's parameter set.""" - return [kernel._initialise_params(key) for kernel in self.kernel_set] + return [kernel.init_params(key) for kernel in self.kernel_set] def __call__( self, @@ -501,7 +509,7 @@ def __call__( K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: params = { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -551,7 +559,7 @@ def __call__( K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -606,7 +614,7 @@ def __call__( ) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -658,7 +666,7 @@ def __call__( ) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -706,7 +714,7 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"]) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -750,7 +758,7 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: K = params["variance"] * jnp.matmul(x.T, y) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return {"variance": jnp.array([1.0])} @@ -796,7 +804,7 @@ def __call__( K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "shift": jnp.array([1.0]), "variance": jnp.array([1.0] * self.ndims), @@ -841,7 +849,7 @@ def __call__( K = jnp.all(jnp.equal(x, y)) * params["variance"] return K.squeeze() - def _initialise_params(self, key: Float[Array, "1 D"]) -> Dict: + def init_params(self, key: Float[Array, "1 D"]) -> Dict: """Initialise the kernel parameters. Args: @@ -889,7 +897,7 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: ) ** (-params["alpha"]) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> dict: + def init_params(self, key: KeyArray) -> dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -939,7 +947,7 @@ def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) return K.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -1044,7 +1052,7 @@ def __call__( ) # shape (n,n) return Kxx.squeeze() - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 898099f6..3cc96be5 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -25,6 +25,8 @@ from jax.random import KeyArray +import deprecation + class AbstractLikelihood(PyTree): """Abstract base class for likelihoods.""" @@ -65,7 +67,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the parameters of the likelihood function. Args: @@ -76,6 +78,15 @@ def _initialise_params(self, key: KeyArray) -> Dict: """ raise NotImplementedError + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + @property @abc.abstractmethod def link_function(self) -> Callable: @@ -110,7 +121,7 @@ def __init__(self, num_datapoints: int, name: Optional[str] = "Gaussian"): """ super().__init__(num_datapoints, name) - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the variance parameter of the likelihood function. Args: @@ -179,7 +190,7 @@ def __init__(self, num_datapoints: int, name: Optional[str] = "Bernoulli"): """ super().__init__(num_datapoints, name) - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a Bernoulli likelihood. Args: diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 65902640..e6073a54 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -21,6 +21,8 @@ from jaxtyping import Array, Float from jaxutils import PyTree +import deprecation + class AbstractMeanFunction(PyTree): """Abstract mean function that is used to parameterise the Gaussian process.""" @@ -51,7 +53,7 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. Args: @@ -62,6 +64,15 @@ def _initialise_params(self, key: KeyArray) -> Dict: """ raise NotImplementedError + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + class Zero(AbstractMeanFunction): """ @@ -92,7 +103,7 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the zero-mean function, this is an empty dictionary. Args: @@ -134,7 +145,7 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value. Args: diff --git a/gpjax/parameters.py b/gpjax/parameters.py index f409905d..5835ab6f 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -74,7 +74,7 @@ def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: if key is None: warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2) key = jr.PRNGKey(123) - params = model._initialise_params(key) + params = model.init_params(key) if kwargs: _validate_kwargs(kwargs, params) diff --git a/gpjax/test_variational_inference.py b/gpjax/test_variational_inference.py deleted file mode 100644 index 1e7eb9eb..00000000 --- a/gpjax/test_variational_inference.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import typing as tp - -import jax -import jax.numpy as jnp -import jax.random as jr -import pytest -from jax.config import config - -import gpjax as gpx -from gpjax.variational_families import ( - CollapsedVariationalGaussian, - ExpectationVariationalGaussian, - NaturalVariationalGaussian, - VariationalGaussian, - WhitenedVariationalGaussian, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_abstract_variational_inference(): - prior = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=20) - post = prior * lik - n_inducing_points = 10 - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - vartiational_family = gpx.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - - with pytest.raises(TypeError): - gpx.variational_inference.AbstractVariationalInference( - posterior=post, vartiational_family=vartiational_family - ) - - -def get_data_and_gp(n_datapoints, point_dim): - x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) - y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 - x = jnp.hstack([x] * point_dim) - D = gpx.Dataset(X=x, y=y) - - p = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=n_datapoints) - post = p * lik - return D, post, p - - -@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) -@pytest.mark.parametrize("jit_fns", [False, True]) -@pytest.mark.parametrize("point_dim", [1, 2, 3]) -@pytest.mark.parametrize( - "variational_family", - [ - VariationalGaussian, - WhitenedVariationalGaussian, - NaturalVariationalGaussian, - ExpectationVariationalGaussian, - ], -) -def test_stochastic_vi( - n_datapoints, n_inducing_points, jit_fns, point_dim, variational_family -): - D, post, prior = get_data_and_gp(n_datapoints, point_dim) - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - - q = variational_family(prior=prior, inducing_inputs=inducing_inputs) - - svgp = gpx.StochasticVI(posterior=post, variational_family=q) - assert svgp.posterior.prior == post.prior - assert svgp.posterior.likelihood == post.likelihood - - params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() - - assert svgp.prior == post.prior - assert svgp.likelihood == post.likelihood - - if jit_fns: - elbo_fn = jax.jit(svgp.elbo(D)) - else: - elbo_fn = svgp.elbo(D) - assert isinstance(elbo_fn, tp.Callable) - elbo_value = elbo_fn(params, D) - assert isinstance(elbo_value, jnp.ndarray) - - # Test gradients - grads = jax.grad(elbo_fn, argnums=0)(params, D) - assert isinstance(grads, tp.Dict) - assert len(grads) == len(params) - - -@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) -@pytest.mark.parametrize("jit_fns", [False, True]) -@pytest.mark.parametrize("point_dim", [1, 2]) -def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): - D, post, prior = get_data_and_gp(n_datapoints, point_dim) - likelihood = gpx.Gaussian(num_datapoints=n_datapoints) - - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) - inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - - q = CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs - ) - - sgpr = gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) - assert sgpr.posterior.prior == post.prior - assert sgpr.posterior.likelihood == post.likelihood - - params, _, _ = gpx.initialise(sgpr, jr.PRNGKey(123)).unpack() - - assert sgpr.prior == post.prior - assert sgpr.likelihood == post.likelihood - - if jit_fns: - elbo_fn = jax.jit(sgpr.elbo(D)) - else: - elbo_fn = sgpr.elbo(D) - assert isinstance(elbo_fn, tp.Callable) - elbo_value = elbo_fn(params) - assert isinstance(elbo_value, jnp.ndarray) - - # Test gradients - grads = jax.grad(elbo_fn)(params) - assert isinstance(grads, tp.Dict) - assert len(grads) == len(params) - - # We should raise an error for non-Collapsed variational families: - with pytest.raises(TypeError): - q = gpx.variational_families.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) - - # We should raise an error for non-Gaussian likelihoods: - with pytest.raises(TypeError): - q = gpx.variational_families.CollapsedVariationalGaussian( - prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs - ) - gpx.variational_inference.CollapsedVI( - posterior=prior * gpx.Bernoulli(num_datapoints=D.n), variational_family=q - ) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index ac7e6ccc..7416fc99 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -32,6 +32,8 @@ from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution +import deprecation + class AbstractVariationalFamily(PyTree): """ @@ -54,7 +56,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: return self.predict(*args, **kwargs) @abc.abstractmethod - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """ The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix. @@ -67,6 +69,15 @@ def _initialise_params(self, key: KeyArray) -> Dict: """ raise NotImplementedError + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + @abc.abstractmethod def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """Predict the GP's output given the input. @@ -114,7 +125,7 @@ class VariationalGaussian(AbstractVariationalGaussian): :math:`\\mu` and sqrt with S = sqrt sqrtᵀ. """ - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """ Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian @@ -129,7 +140,7 @@ def _initialise_params(self, key: KeyArray) -> Dict: m = self.num_inducing return concat_dictionaries( - self.prior._initialise_params(key), + self.prior.init_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, @@ -399,13 +410,13 @@ def __init__( super().__init__(prior, inducing_inputs, name) - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" m = self.num_inducing return concat_dictionaries( - self.prior._initialise_params(key), + self.prior.init_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, @@ -584,7 +595,7 @@ def __init__( super().__init__(prior, inducing_inputs, name) - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" self.num_inducing = self.inducing_inputs.shape[0] @@ -592,7 +603,7 @@ def _initialise_params(self, key: KeyArray) -> Dict: m = self.num_inducing return concat_dictionaries( - self.prior._initialise_params(key), + self.prior.init_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, @@ -764,14 +775,14 @@ def __init__( self.num_inducing = self.inducing_inputs.shape[0] self.name = name - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( - self.prior._initialise_params(key), + self.prior.init_params(key), { "variational_family": {"inducing_inputs": self.inducing_inputs}, "likelihood": { - "obs_noise": self.likelihood._initialise_params(key)["obs_noise"] + "obs_noise": self.likelihood.init_params(key)["obs_noise"] }, }, ) diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 6745656e..036c9ffe 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -36,6 +36,8 @@ CollapsedVariationalGaussian, ) +import deprecation + class AbstractVariationalInference(PyTree): """A base class for inference and training of variational families against an extact posterior""" @@ -56,14 +58,23 @@ def __init__( self.likelihood = self.posterior.likelihood self.variational_family = variational_family - def _initialise_params(self, key: KeyArray) -> Dict: + def init_params(self, key: KeyArray) -> Dict: """Construct the parameter set used within the variational scheme adopted.""" hyperparams = concat_dictionaries( - {"likelihood": self.posterior.likelihood._initialise_params(key)}, - self.variational_family._initialise_params(key), + {"likelihood": self.posterior.likelihood.init_params(key)}, + self.variational_family.init_params(key), ) return hyperparams + @deprecation.deprecated( + deprecated_in="0.5.7", + removed_in="0.6.0", + details="Use the ``init_params`` method for parameter initialisation.", + ) + def _initialise_params(self, key: KeyArray) -> Dict: + """Deprecated method for initialising the GP's parameters. Succeded by ``init_params``.""" + return self.init_params(key) + @abc.abstractmethod def elbo( self, diff --git a/tests/test_kernels.py b/tests/test_kernels.py index df4c8aa0..c111a156 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -64,12 +64,12 @@ def __call__( ) -> Float[Array, "1"]: return x * params["test"] * y - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: PRNGKeyType) -> Dict: return {"test": 1.0} # Initialise dummy kernel class and test __call__ and _init_params methods: dummy_kernel = DummyKernel() - assert dummy_kernel._initialise_params(_initialise_key) == {"test": 1.0} + assert dummy_kernel.init_params(_initialise_key) == {"test": 1.0} assert dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 @@ -116,7 +116,7 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) # Default kernel parameters: - params = kernel._initialise_params(_initialise_key) + params = kernel.init_params(_initialise_key) # Test gram matrix: Kxx = kernel.gram(params, x) @@ -147,7 +147,7 @@ def test_cross_covariance( b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) # Default kernel parameters: - params = kernel._initialise_params(_initialise_key) + params = kernel.init_params(_initialise_key) # Test cross covariance, Kab: Kab = kernel.cross_covariance(params, a, b) @@ -164,7 +164,7 @@ def test_call(kernel: AbstractKernel, dim: int) -> None: y = jnp.array([[0.5] * dim]) # Defualt parameters: - params = kernel._initialise_params(_initialise_key) + params = kernel.init_params(_initialise_key) # Test calling gives an autocovariance value of no dimension between the inputs: kxy = kernel(params, x, y) @@ -302,7 +302,7 @@ def test_initialisation(kernel: AbstractKernel, dim: int) -> None: else: kern = kernel(active_dims=[i for i in range(dim)]) - params = kern._initialise_params(_initialise_key) + params = kern.init_params(_initialise_key) assert list(params.keys()) == ["lengthscale", "variance"] assert all(params["lengthscale"] == jnp.array([1.0] * dim)) @@ -355,7 +355,7 @@ def test_polynomial( assert kern.name == f"Polynomial Degree: {degree}" # Initialise parameters - params = kern._initialise_params(_initialise_key) + params = kern.init_params(_initialise_key) params["shift"] * shift params["variance"] * variance @@ -397,8 +397,8 @@ def test_active_dim(kernel: AbstractKernel) -> None: manual_kern = kernel(active_dims=[i for i in range(perm_length)]) # Get initial parameters - ad_params = ad_kern._initialise_params(_initialise_key) - manual_params = manual_kern._initialise_params(_initialise_key) + ad_params = ad_kern.init_params(_initialise_key) + manual_params = manual_kern.init_params(_initialise_key) # Compute gram matrices ad_Kxx = ad_kern.gram(ad_params, x) @@ -429,7 +429,7 @@ def test_combination_kernel( combination_kernel = combination_type(kernel_set=kernel_set) # Initialise default parameters - params = combination_kernel._initialise_params(_initialise_key) + params = combination_kernel.init_params(_initialise_key) # Check params are a list of dictionaries assert len(params) == n_kerns @@ -470,15 +470,15 @@ def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: sum_kernel = SumKernel(kernel_set=[k1, k2]) # Initialise default parameters - params = sum_kernel._initialise_params(_initialise_key) + params = sum_kernel.init_params(_initialise_key) # Compute gram matrix Kxx = sum_kernel.gram(params, x) # NOW we do the same thing manually and check they are equal: # Initialise default parameters - k1_params = k1._initialise_params(_initialise_key) - k2_params = k2._initialise_params(_initialise_key) + k1_params = k1.init_params(_initialise_key) + k2_params = k2.init_params(_initialise_key) # Compute gram matrix Kxx_k1 = k1.gram(k1_params, x) @@ -524,7 +524,7 @@ def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: prod_kernel = ProductKernel(kernel_set=[k1, k2]) # Initialise default parameters - params = prod_kernel._initialise_params(_initialise_key) + params = prod_kernel.init_params(_initialise_key) # Compute gram matrix Kxx = prod_kernel.gram(params, x) @@ -532,8 +532,8 @@ def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: # NOW we do the same thing manually and check they are equal: # Initialise default parameters - k1_params = k1._initialise_params(_initialise_key) - k2_params = k2._initialise_params(_initialise_key) + k1_params = k1.init_params(_initialise_key) + k2_params = k2.init_params(_initialise_key) # Compute gram matrix Kxx_k1 = k1.gram(k1_params, x) @@ -564,7 +564,7 @@ def test_graph_kernel(): kern.gram # Initialise default parameters - params = kern._initialise_params(_initialise_key) + params = kern.init_params(_initialise_key) assert isinstance(params, dict) assert list(sorted(list(params.keys()))) == [ "lengthscale", diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 4d1d0dd0..d27770b3 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -20,7 +20,7 @@ import jax.random as jr import numpy as np import pytest -from chex import PRNGKey as PRNGKeyType +from jax.random import KeyArray from jax.config import config from jaxtyping import Array, Float @@ -52,7 +52,7 @@ def test_abstract_likelihood(): # Create a dummy likelihood class with abstract methods implemented. class DummyLikelihood(AbstractLikelihood): - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return {} def predict(self, params: Dict, dist: dx.Distribution) -> dx.Distribution: @@ -78,7 +78,7 @@ def test_initialisers(n: int, lik: AbstractLikelihood) -> None: likelihood = lik(num_datapoints=n) # Get default parameter dictionary. - params = likelihood._initialise_params(key) + params = likelihood.init_params(key) # Check parameter dictionary assert list(params.keys()) == true_initialisation[likelihood.name] @@ -93,7 +93,7 @@ def test_bernoulli_predictive_moment(n: int) -> None: likelihood = Bernoulli(num_datapoints=n) # Initialise parameters. - params = likelihood._initialise_params(key) + params = likelihood.init_params(key) # Construct latent function mean and variance values mean_key, var_key = jr.split(key) @@ -123,7 +123,7 @@ def test_link_fns(lik: AbstractLikelihood, n: int) -> None: likelihood = lik(num_datapoints=n) # Initialise parameters. - params = likelihood._initialise_params(key) + params = likelihood.init_params(key) # Test likelihood link function. assert isinstance(likelihood.link_function, Callable) diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index cd98d428..46acdacb 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -18,7 +18,7 @@ import jax.numpy as jnp import jax.random as jr import pytest -from chex import PRNGKey as PRNGKeyType +from jax.random import KeyArray from jax.config import config from jaxtyping import Array, Float @@ -39,7 +39,7 @@ class DummyMeanFunction(AbstractMeanFunction): def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N 1"]: return jnp.ones((x.shape[0], 1)) - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def init_params(self, key: KeyArray) -> Dict: return {} # Test that the dummy mean function can be instantiated. @@ -60,7 +60,7 @@ def test_shape(mean_function: AbstractMeanFunction, n: int, dim: int) -> None: mf = mean_function(output_dim=dim) # Initialise parameters. - params = mf._initialise_params(key) + params = mf.init_params(key) assert isinstance(params, dict) # Test shape of mean function. diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 86e8bda0..c3c507ab 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -47,7 +47,7 @@ class DummyVariationalFamily(AbstractVariationalFamily): def predict(self, params: Dict, x: Float[Array, "N D"]) -> dx.Distribution: return dx.MultivariateNormalDiag(loc=x) - def _initialise_params(self, key: jr.PRNGKey) -> dict: + def init_params(self, key: jr.PRNGKey) -> dict: return {} # Test that the dummy variational family can be instantiated. @@ -138,7 +138,7 @@ def test_variational_gaussians( assert isinstance(q, AbstractVariationalFamily) # Test params and keys: - params = q._initialise_params(jr.PRNGKey(123)) + params = q.init_params(jr.PRNGKey(123)) assert isinstance(params, dict) config_params = gpx.config.get_global_config() @@ -161,7 +161,7 @@ def test_variational_gaussians( assert (moment == value(n_inducing)).all() # Test KL - params = q._initialise_params(jr.PRNGKey(123)) + params = q.init_params(jr.PRNGKey(123)) kl = q.prior_kl(params) assert isinstance(kl, jnp.ndarray) @@ -221,7 +221,7 @@ def test_collapsed_variational_gaussian( assert (variational_family.inducing_inputs == inducing_inputs).all() # Test params - params = variational_family._initialise_params(jr.PRNGKey(123)) + params = variational_family.init_params(jr.PRNGKey(123)) assert isinstance(params, dict) assert "likelihood" in params.keys() assert "obs_noise" in params["likelihood"].keys() @@ -233,7 +233,7 @@ def test_collapsed_variational_gaussian( assert isinstance(params["variational_family"]["inducing_inputs"], jax.Array) # Test predictions - params = variational_family._initialise_params(jr.PRNGKey(123)) + params = variational_family.init_params(jr.PRNGKey(123)) predictive_dist_fn = variational_family(params, D) assert isinstance(predictive_dist_fn, Callable) From f72534cf247deb1cc1f8cf43b975a2395a6706ed Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 8 Jan 2023 15:59:45 +0000 Subject: [PATCH 2/5] Update parameters.py --- gpjax/parameters.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 5835ab6f..69e74d48 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -74,7 +74,21 @@ def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: if key is None: warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2) key = jr.PRNGKey(123) - params = model.init_params(key) + + # Initialise the parameters. + if hasattr(model, "init_params"): + params = model.init_params(key) + + elif hasattr(model, "_initialise_params"): + warn( + "`_initialise_params` is deprecated. Please use `init_params` instead.", + DeprecationWarning, + stacklevel=2, + ) + params = model._initialise_params(key) + + else: + raise AttributeError("No `init_params` or `_initialise_params` method found.") if kwargs: _validate_kwargs(kwargs, params) From 1c3a29f14cb1503a17afe736676bd21b34a38c6e Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 8 Jan 2023 16:10:58 +0000 Subject: [PATCH 3/5] Resolve JaxKern issues. --- examples/haiku.pct.py | 4 ++++ examples/kernels.pct.py | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index fc634e47..b5dccba0 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -113,6 +113,10 @@ def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None: def init_params(self, key: jr.KeyArray) -> Dict: return self._params + # This is depreciated. Can be removed once JaxKern is updated. + def _initialise_params(self, key: jr.KeyArray) -> Dict: + return self.init_params(key) + # %% [markdown] # ### Defining a network diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index a9869daf..1e830075 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -9,7 +9,7 @@ # format_version: '1.3' # jupytext_version: 1.11.2 # kernelspec: -# display_name: Python 3.9.7 ('gpjax') +# display_name: base # language: python # name: python3 # --- @@ -218,9 +218,13 @@ def __call__( K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau return K.squeeze() - def init_params(self, key: jr.PRNGKey) -> dict: + def init_params(self, key: jr.KeyArray) -> dict: return {"tau": jnp.array([4.0])} + # This is depreciated. Can be removed once JaxKern is updated. + def _initialise_params(self, key: jr.KeyArray) -> Dict: + return self.init_params(key) + # %% [markdown] # We unpack this now to make better sense of it. In the kernel's `__init__` From 5911a444259c56ab44606a1069d2213c842ca1ec Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 8 Jan 2023 22:33:57 +0000 Subject: [PATCH 4/5] Depreciate config to jaxutils. --- gpjax/config.py | 157 ++++++++---------------------------------------- gpjax/utils.py | 90 ++++----------------------- 2 files changed, 37 insertions(+), 210 deletions(-) diff --git a/gpjax/config.py b/gpjax/config.py index cabbbfd3..26c592db 100644 --- a/gpjax/config.py +++ b/gpjax/config.py @@ -13,138 +13,33 @@ # limitations under the License. # ============================================================================== -import jax -import distrax as dx -import jax.numpy as jnp -import jax.random as jr -import tensorflow_probability.substrates.jax.bijectors as tfb -from ml_collections import ConfigDict -__config = None +import deprecation -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) -Softplus = dx.Lambda( - forward=lambda x: jnp.log(1 + jnp.exp(x)), - inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), +depreciate = deprecation.deprecated( + deprecated_in="0.5.6", + removed_in="0.6.0", + details="Use method from jaxutils.config instead.", ) - -def reset_global_config() -> None: - global __config - __config = get_default_config() - - -def get_global_config() -> ConfigDict: - """Get the global config file used within GPJax. - - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - global __config - - if __config is None: - __config = get_default_config() - return __config - - # If the global config is available, check if the x64 state has changed - x64_state = jax.config.x64_enabled - - # If the x64 state has not changed, return the existing global config - if x64_state is __config.x64_state: - return __config - - # If the x64 state has changed, return the updated global config - update_x64_sensitive_settings() - return __config - - -def update_x64_sensitive_settings() -> None: - """Update the global config if x64 state changes.""" - global __config - - # Update the x64 state - x64_state = jax.config.x64_enabled - __config.x64_state = x64_state - - # Update the x64 sensitive bijectors - FillScaleTriL = dx.Chain( - [ - tfb.FillScaleTriL(diag_shift=jnp.array(__config.jitter)), - ] - ) - - transformations = __config.transformations - transformations.triangular_transform = FillScaleTriL - - -def get_default_config() -> ConfigDict: - """Construct and return the default config file. - - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - - config = ConfigDict(type_safe=False) - config.key = jr.PRNGKey(123) - - # Set the x64 state - config.x64_state = jax.config.x64_enabled - - # Covariance matrix stabilising jitter - config.jitter = 1e-6 - - FillScaleTriL = dx.Chain( - [ - tfb.FillScaleTriL(diag_shift=jnp.array(config.jitter)), - ] - ) - - # Default bijections - config.transformations = transformations = ConfigDict() - transformations.positive_transform = Softplus - transformations.identity_transform = Identity - transformations.triangular_transform = FillScaleTriL - - # Default parameter transforms - transformations.alpha = "positive_transform" - transformations.lengthscale = "positive_transform" - transformations.variance = "positive_transform" - transformations.smoothness = "positive_transform" - transformations.shift = "positive_transform" - transformations.obs_noise = "positive_transform" - transformations.latent = "identity_transform" - transformations.basis_fns = "identity_transform" - transformations.offset = "identity_transform" - transformations.inducing_inputs = "identity_transform" - transformations.variational_mean = "identity_transform" - transformations.variational_root_covariance = "triangular_transform" - transformations.natural_vector = "identity_transform" - transformations.natural_matrix = "identity_transform" - transformations.expectation_vector = "identity_transform" - transformations.expectation_matrix = "identity_transform" - - return config - - -# This function is created for testing purposes only -def get_global_config_if_exists() -> ConfigDict: - """Get the global config file used within GPJax if it is available. - - Returns: - ConfigDict: A `ConfigDict` describing parameter transforms and default values. - """ - global __config - return __config - - -def add_parameter(param_name: str, bijection: dx.Bijector) -> None: - """Add a parameter and its corresponding transform to GPJax's config file. - - Args: - param_name (str): The name of the parameter that is to be added. - bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value. - """ - lookup_name = f"{param_name}_transform" - get_global_config() - __config.transformations[lookup_name] = bijection - __config.transformations[param_name] = lookup_name +from jaxutils import config + +Identity = config.Identity +Softplus = config.Softplus +reset_global_config = depreciate(config.reset_global_config) +get_global_config = depreciate(config.get_global_config) +get_default_config = depreciate(config.get_default_config) +update_x64_sensitive_settings = depreciate(config.update_x64_sensitive_settings) +get_global_config_if_exists = depreciate(config.get_global_config_if_exists) +add_parameter = depreciate(config.add_parameter) + +__all__ = [ + "Identity", + "Softplus", + "reset_global_config", + "get_global_config", + "get_default_config", + "update_x64_sensitive_settings", + "get_global_config_if_exists", + "set_global_config", +] diff --git a/gpjax/utils.py b/gpjax/utils.py index bb27829f..27dcb507 100644 --- a/gpjax/utils.py +++ b/gpjax/utils.py @@ -13,88 +13,20 @@ # limitations under the License. # ============================================================================== -from typing import Callable, Dict, Tuple +import jaxutils +import deprecation -import jax +depreciate = deprecation.deprecated( + deprecated_in="0.5.6", + removed_in="0.6.0", + details="Use method from jaxutils.config instead.", +) -def concat_dictionaries(a: Dict, b: Dict) -> Dict: - """ - Append one dictionary below another. If duplicate keys exist, then the - key-value pair of the second supplied dictionary will be used. - - Args: - a (Dict): The first dictionary. - b (Dict): The second dictionary. - - Returns: - Dict: The merged dictionary. - """ - return {**a, **b} - - -def merge_dictionaries(base_dict: Dict, in_dict: Dict) -> Dict: - """ - This will return a complete dictionary based on the keys of the first - matrix. If the same key should exist in the second matrix, then the - key-value pair from the first dictionary will be overwritten. The purpose of - this is that the base_dict will be a complete dictionary of values such that - an incomplete second dictionary can be used to update specific key-value - pairs. - - Args: - base_dict (Dict): Complete dictionary of key-value pairs. - in_dict (Dict): Subset of key-values pairs such that values from this - dictionary will take precedent. - - Returns: - Dict: A dictionary with the same keys as the base_dict, but with - values from the in_dict. - """ - for k, _ in base_dict.items(): - if k in in_dict.keys(): - base_dict[k] = in_dict[k] - return base_dict - - -def sort_dictionary(base_dict: Dict) -> Dict: - """ - Sort a dictionary based on the dictionary's key values. - - Args: - base_dict (Dict): The dictionary to be sorted. - - Returns: - Dict: The dictionary sorted alphabetically on the dictionary's keys. - """ - return dict(sorted(base_dict.items())) - - -def dict_array_coercion(params: Dict) -> Tuple[Callable, Callable]: - """ - Construct the logic required to map a dictionary of parameters to an array - of parameters. The values of the dictionary can themselves be dictionaries; - the function should work recursively. - - Args: - params (Dict): The dictionary of parameters that we would like to map - into an array. - - Returns: - Tuple[Callable, Callable]: A pair of functions, the first of which maps - a dictionary to an array, and the second of which maps an array to a - dictionary. The remapped dictionary is equal in structure to the original - dictionary. - """ - flattened_pytree = jax.tree_util.tree_flatten(params) - - def dict_to_array(parameter_dict) -> jax.Array: - return jax.tree_util.tree_flatten(parameter_dict)[0] - - def array_to_dict(parameter_array) -> Dict: - return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array) - - return dict_to_array, array_to_dict +concat_dictionaries = depreciate(jaxutils.dict.concat_dictionaries) +merge_dictionaries = depreciate(jaxutils.dict.merge_dictionaries) +sort_dictionary = depreciate(jaxutils.dict.sort_dictionary) +dict_array_coercion = depreciate(jaxutils.dict.dict_array_coercion) __all__ = [ From 6f69f226b050190637bacfb3b1e1b526830303ac Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Mon, 9 Jan 2023 08:59:42 +0000 Subject: [PATCH 5/5] Bump JaxKern --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d98ddd57..4a172d62 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ def get_versions(): "jaxlib>=0.4.1", "optax", "jaxutils>=0.0.6", - "jaxkern", + "jaxkern>=0.0.4", "distrax>=0.1.2", "tqdm>=4.0.0", "ml-collections==0.1.0",