fix FillScaleTriL, change config, update uncollapsed_vi
patel-zeel committed Dec 18, 2022
1 parent c50d34d commit d6b668c
Showing 14 changed files with 233 additions and 73 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -48,6 +48,20 @@ Feel free to join our [Slack Channel](https://join.slack.com/t/gpjax/shared_invi
> - [**Custom kernels**](https://gpjax.readthedocs.io/en/latest/examples/kernels.html#Custom-Kernel)
> - [**UCI regression**](https://gpjax.readthedocs.io/en/latest/examples/yacht.html)
## Conversion between `.ipynb` and `.py`
The above examples are stored in the [examples](examples) directory in the double percent (`py:percent`) format; a minimal illustration of this format is shown after the commands below. Check out the [jupytext CLI documentation](https://jupytext.readthedocs.io/en/latest/using-cli.html) for more info.

* To convert `example.py` to `example.ipynb`, run:

```bash
jupytext --to notebook example.py
```

* To convert `example.ipynb` to `example.py`, run:

```bash
jupytext --to py:percent example.ipynb
```
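
For reference, a `py:percent` file is a plain `.py` script whose cells are delimited by `# %%` markers. A minimal, illustrative sketch of the format (the contents are arbitrary):

```python
# %% [markdown]
# # An example notebook
# Markdown cells are comments introduced by `# %% [markdown]`.

# %%
# Code cells start with a bare `# %%` marker.
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, num=5)
print(x)
```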

# Simple example

6 changes: 3 additions & 3 deletions examples/tfp_integration.pct.py
@@ -140,9 +140,9 @@ def array_mll(parameter_array):
log_hyper_prior_eval = evaluate_priors(params_dict, unconstrained_priors)

# Evaluate the log-likelihood probability kernel, log [p(y|f, θ) p(f| θ)]:
log_mll_eval = log_mll(gpx.constrain(params_dict, bijectors))
return log_mll_eval + log_hyper_prior_eval
log_mll_eval = log_mll(gpx.constrain(params_dict, bijectors))

return log_mll_eval + log_hyper_prior_eval

return array_mll

67 changes: 63 additions & 4 deletions examples/uncollapsed_vi.pct.py
@@ -7,9 +7,9 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# jupytext_version: 1.14.4
# kernelspec:
# display_name: base
# display_name: gpjax
# language: python
# name: python3
# ---
@@ -27,7 +27,13 @@
from jax import jit
from jax.config import config

import tensorflow_probability.substrates.jax as tfp

tfb = tfp.bijectors

import distrax as dx
import gpjax as gpx
from gpjax.config import get_global_config, reset_global_config

# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
@@ -154,6 +160,7 @@
# Despite introducing inducing inputs into our model, inference can still be intractable with large datasets. To circumvent this, optimisation can be done using stochastic mini-batches.

# %%
reset_global_config()
parameter_state = gpx.initialise(svgp, key)
optimiser = ox.adam(learning_rate=0.01)

@@ -162,7 +169,7 @@
parameter_state=parameter_state,
train_data=D,
optax_optim=optimiser,
n_iters=4000,
n_iters=3000,
key=jr.PRNGKey(42),
batch_size=128,
)
@@ -191,9 +198,61 @@
]
plt.show()

# %% [markdown]
# ## Custom transformations
#
# To train a covariance matrix, `gpjax` uses the `tfb.FillScaleTriL` transformation by default. `tfb.FillScaleTriL` fills a 1D vector into a lower-triangular matrix and then applies a `Softplus` transformation to the diagonal to satisfy the conditions required of a valid Cholesky factor. Users can replace this default with another valid transformation of their choice; for example, a `Square` transformation on the diagonal can also serve the purpose.

# %%
gpx_config = get_global_config()
transformations = gpx_config.transformations
jitter = gpx_config.jitter

triangular_transform = dx.Chain(
[tfb.FillScaleTriL(diag_bijector=tfb.Square(), diag_shift=jnp.array(jitter))]
)

transformations.update({"triangular_transform": triangular_transform})

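# %% [markdown]
# As a quick, illustrative check (the input vector below is arbitrary and not part of the model), the forward pass of the chained bijector fills a vector of length n * (n + 1) / 2 into an n x n lower-triangular matrix whose diagonal, after squaring and the jitter shift, is strictly positive:

# %%
demo_vector = jnp.array([0.1, 0.5, 2.0])  # a length-3 vector fills a 2x2 lower-triangular matrix
print(triangular_transform.forward(demo_vector))
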
# %%
parameter_state = gpx.initialise(svgp, key)
optimiser = ox.adam(learning_rate=0.01)

inference_state = gpx.fit_batches(
objective=negative_elbo,
parameter_state=parameter_state,
train_data=D,
optax_optim=optimiser,
n_iters=3000,
key=jr.PRNGKey(42),
batch_size=128,
)

learned_params, training_history = inference_state.unpack()

# %%
latent_dist = q(learned_params)(xtest)
predictive_dist = likelihood(learned_params, latent_dist)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue")
ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3)
[
ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
for z_i in learned_params["variational_family"]["inducing_inputs"]
]
plt.show()

# %% [markdown]
# We can see that the `Square` transformation achieves a slightly better fit than `Softplus` for the same number of iterations; however, `Softplus` is recommended over `Square` for the stability of optimisation.

# %% [markdown]
# ## System configuration

# %%
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Thomas Pinder & Daniel Dodd'
# %watermark -n -u -v -iv -w -a 'Thomas Pinder, Daniel Dodd & Zeel B Patel'
6 changes: 4 additions & 2 deletions examples/yacht.pct.py
@@ -14,7 +14,6 @@
# ---

# %%
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
@@ -150,7 +149,10 @@
# With an optimal set of parameters learned, we can make predictions on the set of data that we held back right at the start. We'll do this in the usual way by first computing the latent function's distribution before computing the predictive posterior distribution.

# %%
latent_dist = posterior(learned_params, training_data, )(scaled_Xte)
latent_dist = posterior(
learned_params,
training_data,
)(scaled_Xte)
predictive_dist = likelihood(learned_params, latent_dist)

predictive_mean = predictive_dist.mean()
77 changes: 64 additions & 13 deletions gpjax/config.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import jax
import distrax as dx
import jax.numpy as jnp
import jax.random as jr
@@ -21,35 +22,88 @@

__config = None

FillTriangular = dx.Chain(
[
tfb.FillTriangular(),
]
) # TODO: Dan to chain methods.
Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x)
Softplus = dx.Lambda(
forward=lambda x: jnp.log(1 + jnp.exp(x)),
inverse=lambda x: jnp.log(jnp.exp(x) - 1.0),
)


def get_defaults() -> ConfigDict:
"""Construct and globally register the config file used within GPJax.
def reset_global_config() -> None:
global __config
__config = get_default_config()


def get_global_config() -> ConfigDict:
"""Get the global config file used within GPJax.
Returns:
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
"""
global __config

if __config is None:
__config = get_default_config()
return __config

# If the global config is available, check if the x64 state has changed
x64_state = jax.config.x64_enabled

# If the x64 state has not changed, return the existing global config
if x64_state is __config.x64_state:
return __config

# If the x64 state has changed, return the updated global config
update_x64_sensitive_settings()
return __config


def update_x64_sensitive_settings() -> None:
"""Update the global config if x64 state changes."""
global __config

# Update the x64 state
x64_state = jax.config.x64_enabled
__config.x64_state = x64_state

# Update the x64 sensitive bijectors
FillScaleTriL = dx.Chain(
[
tfb.FillScaleTriL(diag_shift=jnp.array(__config.jitter)),
]
) # TODO: Dan to chain methods.

transformations = __config.transformations
transformations.triangular_transform = FillScaleTriL


def get_default_config() -> ConfigDict:
"""Construct and return the default config file.
Returns:
ConfigDict: A `ConfigDict` describing parameter transforms and default values.
"""

config = ConfigDict(type_safe=False)
config.key = jr.PRNGKey(123)

# Set the x64 state
config.x64_state = jax.config.x64_enabled

# Covariance matrix stabilising jitter
config.jitter = 1e-6

FillScaleTriL = dx.Chain(
[
tfb.FillScaleTriL(diag_shift=jnp.array(config.jitter)),
]
) # TODO: Dan to chain methods.

# Default bijections
config.transformations = transformations = ConfigDict()
transformations.positive_transform = Softplus
transformations.identity_transform = Identity
transformations.triangular_transform = FillTriangular
transformations.triangular_transform = FillScaleTriL

# Default parameter transforms
transformations.alpha = "positive_transform"
@@ -69,10 +123,7 @@ def get_defaults() -> ConfigDict:
transformations.expectation_vector = "identity_transform"
transformations.expectation_matrix = "identity_transform"

global __config
if not __config:
__config = config
return __config
return config


def add_parameter(param_name: str, bijection: dx.Bijector) -> None:
@@ -83,6 +134,6 @@ def add_parameter(param_name: str, bijection: dx.Bijector) -> None:
bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value.
"""
lookup_name = f"{param_name}_transform"
get_defaults()
get_global_config()
__config.transformations[lookup_name] = bijection
__config.transformations[param_name] = lookup_name
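
For context, a minimal, hypothetical usage sketch of the registration helper above (the parameter name `my_param` is purely illustrative):

```python
import distrax as dx
import jax.numpy as jnp

from gpjax.config import add_parameter, get_global_config

# Register an exp/log bijection for a custom parameter called "my_param".
add_parameter("my_param", dx.Lambda(forward=jnp.exp, inverse=jnp.log))

# The global config now maps the parameter name to its transform entry.
transformations = get_global_config().transformations
assert transformations["my_param"] == "my_param_transform"
assert isinstance(transformations["my_param_transform"], dx.Lambda)
```
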
12 changes: 6 additions & 6 deletions gpjax/gps.py
@@ -23,7 +23,7 @@

from jaxlinop import identity

from .config import get_defaults
from .config import get_global_config
from .kernels import AbstractKernel
from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate
from .mean_functions import AbstractMeanFunction, Zero
@@ -209,7 +209,7 @@ def predict(
should be evaluated at. The mean function's value at these points is
then returned.
"""
jitter = get_defaults()["jitter"]
jitter = get_global_config()["jitter"]

# Unpack mean function and kernel
mean_function = self.mean_function
@@ -392,7 +392,7 @@ def predict(
function that accepts an input array and returns the predictive
distribution as a ``GaussianDistribution``.
"""
jitter = get_defaults()["jitter"]
jitter = get_global_config()["jitter"]

# Unpack training data
x, y, n = train_data.X, train_data.y, train_data.n
@@ -511,7 +511,7 @@ def marginal_log_likelihood(
of the marginal log-likelihood that can be evaluated at a
given parameter set.
"""
jitter = get_defaults()["jitter"]
jitter = get_global_config()["jitter"]

# Unpack training data
x, y, n = train_data.X, train_data.y, train_data.n
@@ -623,7 +623,7 @@ def predict(
input array and returns the predictive distribution as
a ``dx.Distribution``.
"""
jitter = get_defaults()["jitter"]
jitter = get_global_config()["jitter"]

# Unpack training data
x, n = train_data.X, train_data.n
@@ -706,7 +706,7 @@ def marginal_log_likelihood(
of the marginal log-likelihood that can be evaluated at a given
parameter set.
"""
jitter = get_defaults()["jitter"]
jitter = get_global_config()["jitter"]

# Unpack dataset
x, y, n = train_data.X, train_data.y, train_data.n
4 changes: 0 additions & 4 deletions gpjax/kernels.py
@@ -28,14 +28,10 @@
import jax
from jaxtyping import Array, Float

from .config import get_defaults
from chex import PRNGKey as PRNGKeyType
from jaxutils import PyTree


JITTER = get_defaults()["jitter"]


class AbstractKernelComputation(PyTree):
"""Abstract class for kernel computations."""

4 changes: 2 additions & 2 deletions gpjax/natural_gradients.py
@@ -21,7 +21,7 @@
from jax import value_and_grad
from jaxtyping import Array, Float

from .config import get_defaults
from .config import get_global_config
from .gps import AbstractPosterior
from .parameters import build_trainables, constrain, trainable_params
from .types import Dataset
@@ -32,7 +32,7 @@
)
from .variational_inference import StochasticVI

DEFAULT_JITTER = get_defaults()["jitter"]
DEFAULT_JITTER = get_global_config()["jitter"]


def natural_to_expectation(params: Dict, jitter: float = DEFAULT_JITTER) -> Dict:
4 changes: 2 additions & 2 deletions gpjax/parameters.py
@@ -25,7 +25,7 @@
from chex import dataclass, PRNGKey as PRNGKeyType
from jaxtyping import Array, Float

from .config import Identity, get_defaults
from .config import Identity, get_global_config
from .utils import merge_dictionaries


@@ -154,7 +154,7 @@ def build_bijectors(params: Dict) -> Dict:
Dict: A dictionary that maps each parameter to a bijection.
"""
bijectors = copy_dict_structure(params)
config = get_defaults()
config = get_global_config()
transform_set = config["transformations"]

def recursive_bijectors_list(ps, bs):