Merge pull request #178 from JaxGaussianProcesses/init_params
`init_params` revamp, remove test from `./gpjax`
thomaspinder authored Jan 9, 2023
2 parents 3bbc8cb + 6f69f22 commit c0809b4
Showing 19 changed files with 218 additions and 476 deletions.
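
The central change is an API rename: parameter initialisation moves from the private `_initialise_params` method to a public `init_params` method on priors and kernels, with the old name retained only as a deprecated alias where needed. A minimal before/after sketch, grounded in the `graph_kernels.pct.py` hunk below (`prior`, `key`, and `jnp` are as defined in that example):

# Before this PR:
true_params = prior._initialise_params(key)

# After this PR:
true_params = prior.init_params(key)
true_params["kernel"] = {
    "lengthscale": jnp.array(2.3),
    "variance": jnp.array(3.2),
}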
2 changes: 1 addition & 1 deletion docs/_api.rst
@@ -30,7 +30,7 @@ process objects.
.. autoclass:: AbstractPrior
:members:
:special-members: __call__
:private-members: _initialise_params
:private-members: init_params
:exclude-members: from_tuple, replace, to_tuple

.. autoclass:: AbstractPosterior
4 changes: 2 additions & 2 deletions examples/graph_kernels.pct.py
@@ -85,7 +85,7 @@
kernel = jk.GraphKernel(laplacian=L)
prior = gpx.Prior(kernel=kernel)

true_params = prior._initialise_params(key)
true_params = prior.init_params(key)
true_params["kernel"] = {
"lengthscale": jnp.array(2.3),
"variance": jnp.array(3.2),
@@ -101,7 +101,7 @@
kernel.compute_engine.gram

# %%
kernel.gram(params=kernel._initialise_params(key), inputs=x)
kernel.gram(params=kernel.init_params(key), inputs=x)

# %% [markdown]
#
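The updated notebook relies on the same two-step pattern throughout: ask the kernel for its default parameter dictionary via `init_params`, then pass it to `gram`. A hedged sketch, assuming `kernel`, the inputs `x`, and `key` are defined as earlier in this notebook and that the returned operator exposes `to_dense()` as in the kernels example further below:

params = kernel.init_params(key)    # default parameter dictionary
Kxx = kernel.gram(params=params, inputs=x)
Kxx_dense = Kxx.to_dense()          # materialise the Gram matrix if needed
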
8 changes: 6 additions & 2 deletions examples/haiku.pct.py
@@ -107,12 +107,16 @@ def __call__(

def initialise(self, dummy_x: Float[Array, "1 D"], key: jr.KeyArray) -> None:
nn_params = self.network.init(rng=key, x=dummy_x)
base_kernel_params = self.base_kernel._initialise_params(key)
base_kernel_params = self.base_kernel.init_params(key)
self._params = {**nn_params, **base_kernel_params}

def _initialise_params(self, key: jr.KeyArray) -> Dict:
def init_params(self, key: jr.KeyArray) -> Dict:
return self._params

# This is deprecated. Can be removed once JaxKern is updated.
def _initialise_params(self, key: jr.KeyArray) -> Dict:
return self.init_params(key)


# %% [markdown]
# ### Defining a network
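The new `init_params` method in this example simply returns the parameters assembled in `initialise`, while a thin shim keeps the old `_initialise_params` entry point alive for code paths (e.g. in JaxKern) that have not yet been updated. The same pattern in isolation, as a sketch with an illustrative class and parameter name not taken from the PR (`jnp` is `jax.numpy`):

class MyKernel:
    def init_params(self, key) -> dict:
        # New public initialiser.
        return {"lengthscale": jnp.array([1.0])}

    # Deprecated alias; safe to delete once JaxKern calls init_params directly.
    def _initialise_params(self, key) -> dict:
        return self.init_params(key)
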
26 changes: 15 additions & 11 deletions examples/kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# display_name: base
# language: python
# name: python3
# ---
@@ -97,7 +97,7 @@

# %%
print(f"ARD: {slice_kernel.ard}")
print(f"Lengthscales: {slice_kernel._initialise_params(key)['lengthscale']}")
print(f"Lengthscales: {slice_kernel.init_params(key)['lengthscale']}")

# %% [markdown]
# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.
@@ -107,7 +107,7 @@
x_matrix = jr.normal(key, shape=(50, 5))

# Default parameter dictionary
params = slice_kernel._initialise_params(key)
params = slice_kernel.init_params(key)

# Compute the Gram matrix
K = slice_kernel.gram(params, x_matrix)
@@ -127,9 +127,9 @@
sum_k = k1 + k2

fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k._initialise_params(key), x).to_dense())
im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense())
im2 = ax[2].matshow(sum_k.gram(sum_k.init_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -144,10 +144,10 @@
prod_k = k1 * k2 * k3

fig, ax = plt.subplots(ncols=4, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(k1._initialise_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2._initialise_params(key), x).to_dense())
im2 = ax[2].matshow(k3.gram(k3._initialise_params(key), x).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k._initialise_params(key), x).to_dense())
im0 = ax[0].matshow(k1.gram(k1.init_params(key), x).to_dense())
im1 = ax[1].matshow(k2.gram(k2.init_params(key), x).to_dense())
im2 = ax[2].matshow(k3.gram(k3.init_params(key), x).to_dense())
im3 = ax[3].matshow(prod_k.gram(prod_k.init_params(key), x).to_dense())

fig.colorbar(im0, ax=ax[0])
fig.colorbar(im1, ax=ax[1])
@@ -218,9 +218,13 @@ def __call__(
K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau
return K.squeeze()

def _initialise_params(self, key: jr.PRNGKey) -> dict:
def init_params(self, key: jr.KeyArray) -> dict:
return {"tau": jnp.array([4.0])}

# This is deprecated. Can be removed once JaxKern is updated.
def _initialise_params(self, key: jr.KeyArray) -> Dict:
return self.init_params(key)


# %% [markdown]
# We unpack this now to make better sense of it. In the kernel's `__init__`
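Every composite kernel in the notebook is now initialised the same way before its Gram matrix is evaluated. A condensed sketch of the sum-kernel case, assuming `k1`, `k2`, `x`, and `key` are defined as earlier in the notebook:

sum_k = k1 + k2
sum_params = sum_k.init_params(key)           # parameters of both summands
K_sum = sum_k.gram(sum_params, x).to_dense()  # dense Gram matrix, e.g. for plotting
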
157 changes: 26 additions & 131 deletions gpjax/config.py
@@ -13,138 +13,33 @@
# limitations under the License.
# ==============================================================================

import jax
import distrax as dx
import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
from ml_collections import ConfigDict

__config = None
import deprecation

Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
depreciate = deprecation.deprecated(
deprecated_in="0.5.6",
removed_in="0.6.0",
details="Use method from jaxutils.config instead.",
)


def reset_global_config() -> None:
global __config
__config = get_default_config()


def get_global_config() -> ConfigDict:
"""Get the global config file used within GPJax.
Returns:
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
"""
global __config

if __config is None:
__config = get_default_config()
return __config

# If the global config is available, check if the x64 state has changed
x64_state = jax.config.x64_enabled

# If the x64 state has not changed, return the existing global config
if x64_state is __config.x64_state:
return __config

# If the x64 state has changed, return the updated global config
update_x64_sensitive_settings()
return __config


def update_x64_sensitive_settings() -> None:
"""Update the global config if x64 state changes."""
global __config

# Update the x64 state
x64_state = jax.config.x64_enabled
__config.x64_state = x64_state

# Update the x64 sensitive bijectors
FillScaleTriL = dx.Chain(
[
tfb.FillScaleTriL(diag_shift=jnp.array(__config.jitter)),
]
)

transformations = __config.transformations
transformations.triangular_transform = FillScaleTriL


def get_default_config() -> ConfigDict:
"""Construct and return the default config file.
Returns:
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
"""

config = ConfigDict(type_safe=False)
config.key = jr.PRNGKey(123)

# Set the x64 state
config.x64_state = jax.config.x64_enabled

# Covariance matrix stabilising jitter
config.jitter = 1e-6

FillScaleTriL = dx.Chain(
[
tfb.FillScaleTriL(diag_shift=jnp.array(config.jitter)),
]
)

# Default bijections
config.transformations = transformations = ConfigDict()
transformations.positive_transform = Softplus
transformations.identity_transform = Identity
transformations.triangular_transform = FillScaleTriL

# Default parameter transforms
transformations.alpha = "positive_transform"
transformations.lengthscale = "positive_transform"
transformations.variance = "positive_transform"
transformations.smoothness = "positive_transform"
transformations.shift = "positive_transform"
transformations.obs_noise = "positive_transform"
transformations.latent = "identity_transform"
transformations.basis_fns = "identity_transform"
transformations.offset = "identity_transform"
transformations.inducing_inputs = "identity_transform"
transformations.variational_mean = "identity_transform"
transformations.variational_root_covariance = "triangular_transform"
transformations.natural_vector = "identity_transform"
transformations.natural_matrix = "identity_transform"
transformations.expectation_vector = "identity_transform"
transformations.expectation_matrix = "identity_transform"

return config


# This function is created for testing purposes only
def get_global_config_if_exists() -> ConfigDict:
"""Get the global config file used within GPJax if it is available.
Returns:
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
"""
global __config
return __config


def add_parameter(param_name: str, bijection: dx.Bijector) -> None:
"""Add a parameter and its corresponding transform to GPJax's config file.
Args:
param_name (str): The name of the parameter that is to be added.
bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value.
"""
lookup_name = f"{param_name}_transform"
get_global_config()
__config.transformations[lookup_name] = bijection
__config.transformations[param_name] = lookup_name
from jaxutils import config

Identity = config.Identity
Softplus = config.Softplus
reset_global_config = depreciate(config.reset_global_config)
get_global_config = depreciate(config.get_global_config)
get_default_config = depreciate(config.get_default_config)
update_x64_sensitive_settings = depreciate(config.update_x64_sensitive_settings)
get_global_config_if_exists = depreciate(config.get_global_config_if_exists)
add_parameter = depreciate(config.add_parameter)

__all__ = [
"Identity",
"Softplus",
"reset_global_config",
"get_global_config",
"get_default_config",
"update_x64_sensitive_settings",
"get_global_config_if_exists",
"set_global_config",
]
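
After this commit, `gpjax.config` is a thin compatibility layer: the implementations live in `jaxutils.config`, and each re-export is wrapped with the `deprecation` package so existing imports keep working while flagging removal in 0.6.0. A hedged usage sketch; the warning behaviour shown is the usual behaviour of the `deprecation` package rather than something asserted by this diff:

import warnings
from gpjax import config as gpx_config

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = gpx_config.get_global_config()   # delegates to jaxutils.config
print(any(issubclass(w.category, DeprecationWarning) for w in caught))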