diff --git a/.github/workflows/workflow-master.yml b/.github/workflows/workflow-master.yml index eeab66cc..da8acf75 100644 --- a/.github/workflows/workflow-master.yml +++ b/.github/workflows/workflow-master.yml @@ -5,7 +5,7 @@ on: - "master" pull_request: branches: - - "master" + - "*" jobs: codecov: name: Codecov Workflow diff --git a/docs/conf.py b/docs/conf.py index d84d66c6..f7e060ea 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,10 +103,12 @@ def find_version(*file_paths): bibtex_bibfiles = ["refs.bib"] bibtex_style = "unsrt" bibtex_reference_style = "author_year" +nb_execution_mode = "auto" nbsphinx_allow_errors = True nbsphinx_custom_formats = { ".pct.py": ["jupytext.reads", {"fmt": "py:percent"}], } +jupyter_execute_notebooks = "cache" # Latex commands # mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js" @@ -183,6 +185,11 @@ def find_version(*file_paths): "logo_only": True, "show_toc_level": 2, "repository_url": "https://github.com/thomaspinder/GPJax/", + "launch_buttons": { + "binderhub_url": "https://mybinder.org", + "colab_url": "https://colab.research.google.com", + "notebook_interface": "jupyterlab", + }, "use_repository_button": True, "use_sidenotes": True, # Turns footnotes into sidenotes - https://sphinx-book-theme.readthedocs.io/en/stable/content-blocks.html } @@ -199,10 +206,10 @@ def find_version(*file_paths): # } + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ["_static"] - html_static_path = ["_static"] html_css_files = ["custom.css"] diff --git a/docs/index.md b/docs/index.md index a73af5aa..4c4fa24f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -42,7 +42,6 @@ To learn more, checkout the [regression notebook](https://gpjax.readthedocs.io/en/latest/examples/regression.html). ::: -Are you rendering --- diff --git a/docs/requirements.txt b/docs/requirements.txt index 9f284740..f8305d6e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,6 +6,7 @@ jinja2 nbsphinx>=0.8.0 nb-black==1.0.7 matplotlib==3.3.3 +tensorflow-probability>=0.16.0 sphinx-copybutton networkx>=2.0.0 pandoc diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index b3e0738d..1a3e3c21 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -18,18 +18,19 @@ # # Gaussian Processes Barycentres # # In this notebook we'll give an implementation of . In this work, the existence of a Wasserstein barycentre between a collection of Gaussian processes is proven. When faced with trying to _average_ a set of probability distributions, the Wasserstein barycentre is an attractive choice as it enables uncertainty amongst the individual distributions to be incorporated into the averaged distribution. When compared to a naive _mean of means_ and _mean of variances_ approach to computing the average probability distributions, it can be seen that Wasserstein barycentres offer significantly more favourable uncertainty estimation. 
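# %% [markdown]
# A quick univariate aside (not part of the original notebook, all names below are toy
# placeholders): for one-dimensional Gaussians $\mathcal{N}(\mu_i, \sigma_i^2)$ with
# weights $w_i$, the Wasserstein barycentre is available in closed form, with mean
# $\sum_i w_i \mu_i$ and standard deviation $\sum_i w_i \sigma_i$. The naive
# "mean of variances" averages $\sigma_i^2$ instead, which by Jensen's inequality is
# never smaller, illustrating why the two averages differ.

# %%
import jax.numpy as jnp

mus = jnp.array([-1.0, 0.0, 2.0])      # toy component means
sigmas = jnp.array([0.5, 1.0, 2.0])    # toy component standard deviations
w = jnp.ones(3) / 3.0                  # uniform weights

barycentre_mu = jnp.sum(w * mus)       # weighted mean
barycentre_sigma = jnp.sum(w * sigmas) # weighted standard deviation
naive_variance = jnp.sum(w * sigmas**2)  # naive mean-of-variances

print(barycentre_sigma**2, naive_variance)  # barycentre variance <= naive variance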
+# +# %% import typing as tp -import distrax as dx import jax import jax.numpy as jnp import jax.random as jr import jax.scipy.linalg as jsl import matplotlib.pyplot as plt +import numpyro import optax as ox -# %% import gpjax as gpx key = jr.PRNGKey(123) @@ -99,24 +100,23 @@ def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray): D = gpx.Dataset(X=x, y=y) likelihood = gpx.Gaussian(num_datapoints=n) posterior = gpx.Prior(kernel=gpx.RBF()) * likelihood - params, trainables, constrainers, unconstrainers = gpx.initialise( + parameter_state = gpx.initialise( posterior, key - ).unpack() - params = gpx.transform(params, unconstrainers) + ) + params, trainables, bijectors = parameter_state.unpack() + params = gpx.unconstrain(params, bijectors) objective = jax.jit( - posterior.marginal_log_likelihood(D, constrainers, negative=True) + posterior.marginal_log_likelihood(D, negative=True) ) opt = ox.adam(learning_rate=0.01) learned_params, training_history = gpx.fit( objective=objective, - trainables=trainables, - params=params, + parameter_state=parameter_state, optax_optim=opt, n_iters=1000, ).unpack() - learned_params = gpx.transform(learned_params, constrainers) return likelihood(posterior(D, learned_params)(xtest), learned_params) @@ -134,9 +134,9 @@ def sqrtm(A: jnp.DeviceArray): def wasserstein_barycentres( - distributions: tp.List[dx.Distribution], weights: jnp.DeviceArray + distributions: tp.List[numpyro.distributions.Distribution], weights: jnp.DeviceArray ): - covariances = [d.covariance() for d in distributions] + covariances = [d.covariance_matrix for d in distributions] cov_stack = jnp.stack(covariances) stack_sqrt = jax.vmap(sqrtm)(cov_stack) @@ -156,7 +156,7 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray): # %% weights = jnp.ones((n_datasets,)) / n_datasets -means = jnp.stack([d.mean() for d in posterior_preds]) +means = jnp.stack([d.mean for d in posterior_preds]) barycentre_mean = jnp.tensordot(weights, means, axes=1) step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights)) @@ -167,7 +167,7 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray): ) -barycentre_process = dx.MultivariateNormalFullCovariance( +barycentre_process = numpyro.distributions.MultivariateNormal( barycentre_mean, barycentre_covariance ) @@ -179,16 +179,16 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray): # %% def plot( - dist: dx.Distribution, + dist: numpyro.distributions.Distribution, ax, color: str = "tab:blue", label: str = None, ci_alpha: float = 0.2, linewidth: float = 1.0, ): - mu = dist.mean() - sigma = dist.stddev() - ax.plot(xtest, dist.mean(), linewidth=linewidth, color=color, label=label) + mu = dist.mean + sigma = jnp.sqrt(dist.covariance_matrix.diagonal()) + ax.plot(xtest, dist.mean, linewidth=linewidth, color=color, label=label) ax.fill_between( xtest.squeeze(), mu - sigma, mu + sigma, alpha=ci_alpha, color=color ) diff --git a/examples/classification.pct.py b/examples/classification.pct.py index 0b10a9b2..fc401e21 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -19,18 +19,16 @@ # # In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). 
We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling. -import blackjax -import distrax as dx - # %% import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp import matplotlib.pyplot as plt +import numpyro.distributions as npd import optax as ox from jaxtyping import Array, Float - +import blackjax import gpjax as gpx from gpjax.utils import I @@ -72,37 +70,33 @@ # To begin we obtain a set of initial parameter values through the `initialise` callable, and transform these to the unconstrained space via `transform` (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We also define the negative marginal log-likelihood, and JIT compile this to accelerate training. # %% parameter_state = gpx.initialise(posterior) -params, trainable, constrainer, unconstrainer = parameter_state.unpack() -params = gpx.transform(params, unconstrainer) +params, trainable, bijectors = parameter_state.unpack() -mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True)) +mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True)) # %% [markdown] # We can obtain a MAP estimate by optimising the marginal log-likelihood with Obtax's optimisers. # %% opt = ox.adam(learning_rate=0.01) -unconstrained_params, training_history = gpx.fit( +learned_params, training_history = gpx.fit( mll, - params, - trainable, + parameter_state, opt, n_iters=500, ).unpack() -negative_Hessian = jax.jacfwd(jax.jacrev(mll))(unconstrained_params)["latent"][ +negative_Hessian = jax.jacfwd(jax.jacrev(mll))(learned_params)["latent"][ "latent" ][:, 0, :, 0] - -map_estimate = gpx.transform(unconstrained_params, constrainer) # %% [markdown] # From which we can make predictions at novel inputs, as illustrated below. # %% -latent_dist = posterior(D, map_estimate)(xtest) +latent_dist = posterior(D, learned_params)(xtest) -predictive_dist = likelihood(latent_dist, map_estimate) +predictive_dist = likelihood(latent_dist, learned_params) -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_std = jnp.sqrt(predictive_dist.variance) fig, ax = plt.subplots(figsize=(12, 5)) ax.plot(x, y, "o", label="Observations", color="tab:red") @@ -139,7 +133,7 @@ # The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. # Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below. # %% -f_map_estimate = posterior(D, map_estimate)(x).mean() +f_map_estimate = posterior(D, learned_params)(x).mean jitter = 1e-6 @@ -150,7 +144,7 @@ L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True) H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False) -laplace_approximation = dx.MultivariateNormalFullCovariance(f_map_estimate, H_inv) +laplace_approximation = npd.MultivariateNormal(f_map_estimate, H_inv) from gpjax.kernels import cross_covariance, gram @@ -161,26 +155,26 @@ def predict( - laplace_at_data: dx.Distribution, + laplace_at_data: npd.Distribution, train_data: Dataset, test_inputs: Float[Array, "N D"], jitter: int = 1e-6, -) -> dx.Distribution: +) -> npd.Distribution: """Compute the predictive distribution of the Laplace approximation at novel inputs. 
Args: laplace_at_data (dict): The Laplace approximation at the datapoints. Returns: - dx.Distribution: The Laplace approximation at novel inputs. + npd.Distribution: The Laplace approximation at novel inputs. """ x, n = train_data.X, train_data.n t = test_inputs n_test = t.shape[0] - mu = laplace_at_data.mean().reshape(-1, 1) - cov = laplace_at_data.covariance() + mu = laplace_at_data.mean.reshape(-1, 1) + cov = laplace_at_data.covariance_matrix Ktt = gram(prior.kernel, t, params["kernel"]) Kxx = gram(prior.kernel, x, params["kernel"]) @@ -214,7 +208,7 @@ def predict( ) covariance += I(n_test) * jitter - return dx.MultivariateNormalFullCovariance( + return npd.MultivariateNormal( jnp.atleast_1d(mean.squeeze()), covariance ) @@ -224,10 +218,10 @@ def predict( # %% latent_dist = predict(laplace_approximation, D, xtest) -predictive_dist = likelihood(latent_dist, map_estimate) +predictive_dist = likelihood(latent_dist, learned_params) -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_std = predictive_dist.variance**0.5 fig, ax = plt.subplots(figsize=(12, 5)) ax.plot(x, y, "o", label="Observations", color="tab:red") @@ -274,9 +268,9 @@ def predict( # %% # Adapted from BlackJax's introduction notebook. num_adapt = 500 -num_samples = 500 +num_samples = 200 -mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False)) +mll = jax.jit(posterior.marginal_log_likelihood(D, negative=False)) adapt = blackjax.window_adaptation( blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65 @@ -334,11 +328,10 @@ def one_step(state, rng_key): ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i] ps["kernel"]["variance"] = states.position["kernel"]["variance"][i] ps["latent"] = states.position["latent"][i, :, :] - ps = gpx.transform(ps, constrainer) latent_dist = posterior(D, ps)(xtest) predictive_dist = likelihood(latent_dist, ps) - samples.append(predictive_dist.sample(seed=key, sample_shape=(10,))) + samples.append(predictive_dist.sample(key, sample_shape=(10,))) samples = jnp.vstack(samples) diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index f80431b2..e8b3ce15 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -86,22 +86,21 @@ # %% [markdown] # We now train our model akin to a Gaussian process regression model via the `fit` abstraction. Unlike the regression example given in the [conjugate regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html), the inducing locations that induce our variational posterior distribution are now part of the model's parameters. Using a gradient-based optimiser, we can then _optimise_ their location such that the evidence lower bound is maximised. 
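# %% [markdown]
# For intuition, the cell below is a toy sketch of the gradient loop that the `fit`
# abstraction wraps: an Optax optimiser repeatedly updates a parameter pytree to
# minimise a (negative) objective. The names `toy_params` and `toy_objective` are
# stand-ins invented for this illustration; the actual call used in this notebook is
# `gpx.fit(objective=..., parameter_state=..., optax_optim=..., n_iters=...)` as shown
# in the next cell.

# %%
import jax
import jax.numpy as jnp
import optax as ox

# Toy stand-in for a negative objective over a parameter pytree.
toy_params = {"lengthscale": jnp.array(2.0), "variance": jnp.array(0.5)}
toy_objective = lambda p: (p["lengthscale"] - 1.0) ** 2 + (p["variance"] - 1.0) ** 2

toy_opt = ox.adam(learning_rate=0.1)
toy_opt_state = toy_opt.init(toy_params)

for _ in range(200):
    loss, grads = jax.value_and_grad(toy_objective)(toy_params)
    updates, toy_opt_state = toy_opt.update(grads, toy_opt_state)
    toy_params = ox.apply_updates(toy_params, updates)

print(toy_params)  # both entries move towards the minimiser at 1.0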
# %% -params, trainables, constrainers, unconstrainers = gpx.initialise(sgpr, key).unpack() +parameter_state = gpx.initialise(sgpr, key) -loss_fn = jit(sgpr.elbo(D, constrainers, negative=True)) +loss_fn = jit(sgpr.elbo(D, negative=True)) optimiser = ox.adam(learning_rate=0.005) -params = gpx.transform(params, unconstrainers) - learned_params, training_history = gpx.fit( objective=loss_fn, - params=params, - trainables=trainables, + parameter_state=parameter_state, optax_optim=optimiser, n_iters=2000, ).unpack() -learned_params = gpx.transform(learned_params, constrainers) + +# %% +parameter_state.bijectors # %% [markdown] # We show predictions of our model with the learned inducing points overlayed in grey. @@ -109,10 +108,10 @@ latent_dist = q.predict(D, learned_params)(xtest) predictive_dist = likelihood(latent_dist, learned_params) -samples = latent_dist.sample(seed=key, sample_shape=20) +samples = latent_dist.sample(key, sample_shape=(20, )) -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_std = jnp.sqrt(predictive_dist.variance) fig, ax = plt.subplots(figsize=(12, 5)) @@ -166,16 +165,15 @@ # %% full_rank_model = gpx.Prior(kernel=gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n) -fr_params, fr_trainables, fr_constrainers, fr_unconstrainers = gpx.initialise( +fr_params, fr_trainables, fr_bijectors = gpx.initialise( full_rank_model, key ).unpack() -fr_params = gpx.transform(fr_params, fr_unconstrainers) -mll = jit(full_rank_model.marginal_log_likelihood(D, fr_constrainers, negative=True)) -# %timeit mll(fr_params).block_until_ready() +mll = jit(full_rank_model.marginal_log_likelihood(D, negative=True)) +%timeit mll(fr_params).block_until_ready() # %% -sparse_elbo = jit(sgpr.elbo(D, constrainers, negative=True)) -# %timeit sparse_elbo(params).block_until_ready() +sparse_elbo = jit(sgpr.elbo(D, negative=True)) +%timeit sparse_elbo(params).block_until_ready() # %% [markdown] # As we can see, the sparse approximation given here is around 50 times faster when compared against a full-rank model. 
diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index 0069ad4e..656983fb 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -88,7 +88,7 @@ } fx = f(true_params)(x) -y = fx.sample(seed=key).reshape(-1, 1) +y = fx.sample(key).reshape(-1, 1) D = gpx.Dataset(X=x, y=y) @@ -145,8 +145,8 @@ initial_dist = likelihood(posterior(D, params)(x), params) predictive_dist = likelihood(posterior(D, learned_params)(x), learned_params) -initial_mean = initial_dist.mean() -learned_mean = predictive_dist.mean() +initial_mean = initial_dist.mean +learned_mean = predictive_dist.mean rmse = lambda ytrue, ypred: jnp.sum(jnp.sqrt(jnp.square(ytrue - ypred))) diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index b2434433..42d8820b 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -19,9 +19,9 @@ # # In this notebook we demonstrate how GPJax can be used in conjunction with [Haiku](https://github.com/deepmind/dm-haiku) to build deep kernel Gaussian processes. Modelling data with discontinuities is a challenging task for regular Gaussian process models. However, as shown in , transforming the inputs to our Gaussian process model's kernel through a neural network can offer a solution to this. +# %% import typing as tp -import distrax as dx import haiku as hk import jax import jax.numpy as jnp @@ -31,7 +31,6 @@ from chex import dataclass from scipy.signal import sawtooth -# %% import gpjax as gpx from gpjax.kernels import Kernel @@ -179,8 +178,8 @@ def forward(x): latent_dist = posterior(D, final_params)(xtest) predictive_dist = likelihood(latent_dist, final_params) -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_std = jnp.sqrt(predictive_dist.variance) fig, ax = plt.subplots(figsize=(12, 5)) ax.plot(x, y, "o", label="Observations", color="tab:red") diff --git a/examples/intro_to_gps.pct.py b/examples/intro_to_gps.pct.py index dc758e99..76f487e7 100644 --- a/examples/intro_to_gps.pct.py +++ b/examples/intro_to_gps.pct.py @@ -1,11 +1,12 @@ # --- # jupyter: # jupytext: +# custom_cell_magics: kql # text_representation: # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -105,15 +106,19 @@ # $$ # We can plot three different parameterisations of this density. 
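# %% [markdown]
# As a quick check (an illustrative aside, with toy values chosen here): NumPyro
# distributions expose `log_prob` rather than a `prob` method, so densities are
# recovered with `jnp.exp`. The value agrees with the closed-form Gaussian density
# given above.

# %%
import jax.numpy as jnp
import numpyro.distributions as npd

mu, sigma, xval = 0.0, 1.0, 0.3
d = npd.Normal(mu, sigma)

density_from_log_prob = jnp.exp(d.log_prob(xval))
density_closed_form = jnp.exp(-0.5 * ((xval - mu) / sigma) ** 2) / (
    sigma * jnp.sqrt(2.0 * jnp.pi)
)

print(density_from_log_prob, density_closed_form)  # equal up to floating point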
-import distrax as dx - # %% +import numpyro.distributions as npd import jax.numpy as jnp import matplotlib.pyplot as plt +from utils import confidence_ellipse +import jax.random as jr +import pandas as pd +import seaborn as sns + -ud1 = dx.Normal(0.0, 1.0) -ud2 = dx.Normal(-1.0, 0.5) -ud3 = dx.Normal(0.25, 1.5) +ud1 = npd.Normal(0.0, 1.0) +ud2 = npd.Normal(-1.0, 0.5) +ud3 = npd.Normal(0.25, 1.5) xs = jnp.linspace(-5.0, 5.0, 500) @@ -121,10 +126,10 @@ for d in [ud1, ud2, ud3]: ax.plot( xs, - d.prob(xs), - label=f"$\mathcal{{N}}({{{float(d.mean())}}},\ {{{float(d.stddev())}}}^2)$", + jnp.exp(d.log_prob(xs)), + label=f"$\mathcal{{N}}({{{float(d.mean)}}},\ {{{float(jnp.sqrt(d.variance))}}}^2)$", ) - ax.fill_between(xs, jnp.zeros_like(xs), d.prob(xs), alpha=0.2) + ax.fill_between(xs, jnp.zeros_like(xs), jnp.exp(d.log_prob(xs)), alpha=0.2) ax.legend(loc="best") # %% [markdown] @@ -151,19 +156,15 @@ # $$ # Three example parameterisations of this can be visualised below where $\rho$ determines the correlation of the multivariate Gaussian. -import jax.random as jr - # %% -from utils import confidence_ellipse - key = jr.PRNGKey(123) -d1 = dx.MultivariateNormalDiag(jnp.zeros(2), scale_diag=jnp.array([1.0, 1.0])) -d2 = dx.MultivariateNormalTri( - jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]])) +d1 = npd.MultivariateNormal(jnp.zeros(2), jnp.eye(2)) +d2 = npd.MultivariateNormal( + jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]])) ) -d3 = dx.MultivariateNormalTri( - jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]])) +d3 = npd.MultivariateNormal( + jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]])) ) dists = [d1, d2, d3] @@ -181,15 +182,15 @@ titles = [r"$\rho = 0$", r"$\rho = 0.9$", r"$\rho = -0.5$"] for a, t, d in zip([ax0, ax1, ax2], titles, dists): - d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape( + d_prob = jnp.exp(d.log_prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape( xx.shape - ) - cntf = a.contourf(xx, yy, d_prob, levels=20, antialiased=True, cmap="Reds") + )) + cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap="Reds") for c in cntf.collections: c.set_edgecolor("face") a.set_xlim(-2.75, 2.75) a.set_ylim(-2.75, 2.75) - samples = d.sample(seed=key, sample_shape=(5000,)) + samples = d.sample(key, sample_shape=(5000,)) xsample, ysample = samples[:, 0], samples[:, 1] confidence_ellipse( xsample, ysample, a, edgecolor="#3f3f3f", n_std=1.0, linestyle="--", alpha=0.8 @@ -236,19 +237,15 @@ # joint distribution $p(\mathbf{x}, \mathbf{y})$ quantifies the probability of two events, one # from $p(\mathbf{x})$ and another from $p(\mathbf{y})$, occurring at the same time. We visualise this idea below. 
-import pandas as pd - # %% -import seaborn as sns - n = 1000 -x = dx.Normal(loc=0.0, scale=1.0).sample(seed=key, sample_shape=(n,)) +x = npd.Normal(loc=0.0, scale=1.0).sample(key, sample_shape=(n,)) key, subkey = jr.split(key) -y = dx.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n,)) +y = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n,)) key, subkey = jr.split(subkey) -xfull = dx.Normal(loc=0.0, scale=1.0).sample(seed=subkey, sample_shape=(n * 10,)) +xfull = npd.Normal(loc=0.0, scale=1.0).sample(subkey, sample_shape=(n * 10,)) key, subkey = jr.split(subkey) -yfull = dx.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n * 10,)) +yfull = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n * 10,)) key, subkey = jr.split(subkey) df = pd.DataFrame({"x": x, "y": y, "idx": jnp.ones(n)}) diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index adfea990..ebb46ec5 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -19,17 +19,15 @@ # # In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones. -import distrax as dx -import jax +# %% import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt -import tensorflow_probability.substrates.jax.bijectors as tfb from jax import jit from jaxtyping import Array, Float from optax import adam +import numpyro.distributions as npd -# %% import gpjax as gpx key = jr.PRNGKey(123) @@ -62,9 +60,9 @@ for k, ax in zip(kernels, axes.ravel()): prior = gpx.Prior(kernel=k) - params, _, _, _ = gpx.initialise(prior, key).unpack() + params, _, _ = gpx.initialise(prior, key).unpack() rv = prior(params)(x) - y = rv.sample(sample_shape=10, seed=key) + y = rv.sample(key, sample_shape=(10,)) ax.plot(x, y.T, alpha=0.7) ax.set_title(k.name) @@ -207,14 +205,31 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: # # To define a bijector here we'll make use of the `Lambda` operator given in Distrax. This lets us convert any regular Jax function into a bijection. Given that we require $\tau$ to be strictly greater than $4.$, we'll apply a [softplus transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) where the lower bound is shifted by $4.$. 
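# %% [markdown]
# A quick sanity check of the shifted-softplus idea before defining the transform in
# the next cell (the values and the explicit inverse below are illustrative, not the
# library's implementation): shifting the input and output of a softplus by the lower
# bound maps any real number to a value strictly above $4$, and the map is invertible.

# %%
import jax.numpy as jnp
from jax.nn import softplus

low = 4.0
unconstrained = jnp.array([-10.0, 0.0, 3.0])

constrained = softplus(unconstrained - low) + low        # every entry is > 4
recovered = jnp.log(jnp.expm1(constrained - low)) + low  # softplus inverse, then unshift
print(constrained, recovered)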
- # %% from gpjax.config import add_parameter +from jax.nn import softplus + + +class ShiftSoftplus(npd.transforms.Transform): + domain = npd.constraints.real + codomain = npd.constraints.real + + def __init__(self, low: Float[Array, "1"]) -> None: + super().__init__() + self.low = jnp.array(low) + + def __call__(self, x): + x -= self.low + return softplus(x) + self.low -bij_fn = lambda x: jax.nn.softplus(x + jnp.array(4.0)) -bij = dx.Lambda(bij_fn) + def _inverse(self, y): + return jnp.log(-jnp.expm1(-y)) + y -add_parameter("tau", bij) + def log_abs_det_jacobian(self, x, y, intermediates=None): + return -softplus(-x) + + +add_parameter("tau", ShiftSoftplus(4.)) # %% [markdown] # ### Using our polar kernel @@ -239,32 +254,27 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: circlular_posterior = gpx.Prior(kernel=PKern) * likelihood # Initialise parameters and corresponding transformations -params, trainable, constrainer, unconstrainer = gpx.initialise( - circlular_posterior, key -).unpack() +parameter_state = gpx.initialise(circlular_posterior, key) # Optimise GP's marginal log-likelihood using Adam -mll = jit(circlular_posterior.marginal_log_likelihood(D, constrainer, negative=True)) +mll = jit(circlular_posterior.marginal_log_likelihood(D, negative=True)) + learned_params, training_history = gpx.fit( mll, - params, - trainable, - adam(learning_rate=0.05), + parameter_state, + adam(learning_rate=0.01), n_iters=1000, ).unpack() -# Untransform learned parameters -final_params = gpx.transform(learned_params, constrainer) - # %% [markdown] # ### Prediction # # We'll now query the GP's predictive posterior at linearly spaced novel inputs and illustrate the results. # %% -posterior_rv = likelihood(circlular_posterior(D, final_params)(angles), final_params) -mu = posterior_rv.mean() -one_sigma = posterior_rv.stddev() +posterior_rv = likelihood(circlular_posterior(D, learned_params)(angles), learned_params) +mu = posterior_rv.mean +one_sigma = jnp.sqrt(posterior_rv.variance) # %% fig = plt.figure(figsize=(10, 8)) diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 62f319e4..1e976025 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -6,7 +6,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -86,9 +86,9 @@ parameter_state = gpx.initialise(prior, key) prior_dist = prior(parameter_state.params)(xtest) -prior_mean = prior_dist.mean() -prior_std = jnp.sqrt(prior_dist.covariance().diagonal()) -samples = prior_dist.sample(seed=key, sample_shape=20).T +prior_mean = prior_dist.mean +prior_std = jnp.sqrt(prior_dist.covariance_matrix.diagonal()) +samples = prior_dist.sample(key, sample_shape=(20,)).T plt.plot(xtest, samples, color="tab:blue", alpha=0.5) plt.plot(xtest, prior_mean, color="tab:orange") @@ -146,20 +146,21 @@ # We can now unpack the `ParameterState` to receive each of the four components listed above. # %% -params, trainable, constrainer, unconstrainer = parameter_state.unpack() +params, trainable, bijectors = parameter_state.unpack() pp.pprint(params) # %% [markdown] # To motivate the purpose of `constrainer` and `unconstrainer` more precisely, notice that our model hyperparameters $\{\ell^2, \sigma^2, \alpha^2 \}$ are all strictly positive. To ensure more stable optimisation, it is strongly advised to transform the parameters onto an unconstrained space first via `transform`. 
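# %% [markdown]
# For illustration (an aside, assuming the Softplus bijector registered as the positive
# transform in the updated `gpjax.config`): a strictly positive hyperparameter such as a
# lengthscale can be carried back and forth between the constrained and unconstrained
# spaces with a NumPyro transform, and the optimiser works with the unconstrained
# representation.

# %%
import jax.numpy as jnp
import numpyro.distributions as npd

positive = npd.transforms.SoftplusTransform()

lengthscale = jnp.array(0.5)     # constrained (strictly positive) value
raw = positive.inv(lengthscale)  # unconstrained representation
roundtrip = positive(raw)        # back to the positive value

print(raw, roundtrip)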
# %% -params = gpx.transform(params, unconstrainer) +# params = gpx.constrain(params, bijectors) +pp.pprint(params) # %% [markdown] # To train our hyperparameters, we optimising the marginal log-likelihood of the posterior with respect to them. We define the marginal log-likelihood with `marginal_log_likelihood` on the posterior. # %% -mll = jit(posterior.marginal_log_likelihood(D, constrainer, negative=True)) +mll = jit(posterior.marginal_log_likelihood(D, negative=True)) mll(params) # %% [markdown] # Since most optimisers (including here) minimise a given function, we have realised the negative marginal log-likelihood and just-in-time (JIT) compiled this to accelerate training. @@ -170,10 +171,9 @@ # %% opt = ox.adam(learning_rate=0.01) inference_state = gpx.fit( - mll, - params, - trainable, - opt, + objective = mll, + parameter_state = parameter_state, + optax_optim = opt, n_iters=500, ) @@ -182,14 +182,7 @@ # %% final_params, training_history = inference_state.unpack() - -# %% [markdown] -# -# The exact value of our learned parameters is often useful in answering certain questions about the underlying process. To obtain these values, we untransfom our trained unconstrained parameters back to their original constrained space with `transform` and `constrainer`. - -# %% -final_params = gpx.transform(final_params, constrainer) -pp.pprint(final_params) +final_params # %% [markdown] # ## Prediction @@ -200,8 +193,8 @@ latent_dist = posterior(D, final_params)(xtest) predictive_dist = likelihood(latent_dist, final_params) -predictive_mean = predictive_dist.mean() -predictive_std = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_std = jnp.sqrt(predictive_dist.covariance_matrix.diagonal()) # %% [markdown] # With the predictions and their uncertainty acquired, we illustrate the GP's performance at explaining the data $\mathcal{D}$ and recovering the underlying latent function of interest. diff --git a/examples/tfp_integration.pct.py b/examples/tfp_integration.pct.py index ba26c468..7e6534a6 100644 --- a/examples/tfp_integration.pct.py +++ b/examples/tfp_integration.pct.py @@ -7,7 +7,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -202,8 +202,8 @@ def run_chain(key, state): predictive_dist = likelihood(posterior(D, learned_params)(xtest), learned_params) -mu = predictive_dist.mean() -sigma = predictive_dist.stddev() +mu = predictive_dist.mean +sigma = jnp.sqrt(predictive_dist.variance) # %% [markdown] # Finally, we plot the learned posterior predictive distribution evaluated at the test points defined above. diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 6ad8d60a..2133c4b9 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -7,9 +7,9 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: -# display_name: Python 3.10.0 ('base') +# display_name: Python 3.9.7 ('gpjax') # language: python # name: python3 # --- @@ -137,10 +137,13 @@ # # Since Optax's optimisers work to minimise functions, to maximise the ELBO we return its negative. 
# %% -params, trainables, constrainers, unconstrainers = gpx.initialise(svgp, key).unpack() -params = gpx.transform(params, unconstrainers) +parameter_state = gpx.initialise(svgp, key) +params, trainables, bijectors = parameter_state.unpack() +loss_fn = svgp.elbo(D, negative=True) -loss_fn = jit(svgp.elbo(D, constrainers, negative=True)) +# %% +b = gpx.abstractions.get_batch(D, 64, key) +loss_fn(params, b) # %% [markdown] # ### Mini-batching @@ -152,8 +155,7 @@ inference_state = gpx.fit_batches( objective=loss_fn, - params=params, - trainables=trainables, + parameter_state=parameter_state, train_data=D, optax_optim=optimiser, n_iters=4000, @@ -161,7 +163,6 @@ batch_size=128, ) learned_params, training_history = inference_state.unpack() -learned_params = gpx.transform(learned_params, constrainers) # %% [markdown] # ## Predictions # @@ -172,8 +173,8 @@ latent_dist = q(learned_params)(xtest) predictive_dist = likelihood(latent_dist, learned_params) -meanf = predictive_dist.mean() -sigma = predictive_dist.stddev() +meanf = predictive_dist.mean +sigma = predictive_dist.variance ** 0.5 fig, ax = plt.subplots(figsize=(12, 5)) ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray") diff --git a/examples/utils.py b/examples/utils.py index aa287445..b8bdecd1 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,7 +1,5 @@ import matplotlib.transforms as transforms import numpy as np -import tensorflow as tf -import utils as ut from matplotlib.patches import Ellipse diff --git a/examples/yacht.pct.py b/examples/yacht.pct.py index bb5683a1..b549ff27 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -1,11 +1,12 @@ # --- # jupyter: # jupytext: +# custom_cell_magics: kql # text_representation: # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.11.5 +# jupytext_version: 1.11.2 # kernelspec: # display_name: Python 3.9.7 ('gpjax') # language: python @@ -17,13 +18,12 @@ # # In this notebook we'll demonstrate the use of GPJax on a benchmark UCI regression problem. Such tasks are commonly used within the research community to benchmark and evaluate new techniques against those already present in the literature. Much of the code contained in this notebook can be adapted to applied problems concerning datasets other than the one presented here. 
+# %% import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np import optax as ox - -# %% import pandas as pd from jax import jit from sklearn.metrics import mean_squared_error, r2_score @@ -153,8 +153,8 @@ latent_dist = posterior(training_data, learned_params)(scaled_Xte) predictive_dist = likelihood(latent_dist, learned_params) -predictive_mean = predictive_dist.mean() -predictive_stddev = predictive_dist.stddev() +predictive_mean = predictive_dist.mean +predictive_stddev = jnp.sqrt(predictive_dist.variance) # %% [markdown] # ## Evaluation diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 661c5cb8..2e83fc69 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -47,7 +47,7 @@ __description__ = "Didactic Gaussian processes in JAX" __url__ = "https://github.com/thomaspinder/GPJax" __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors" -__version__ = "0.4.13" +__version__ = "0.5.0" __all__ = [ diff --git a/gpjax/config.py b/gpjax/config.py index 5f6dcdac..88592bdc 100644 --- a/gpjax/config.py +++ b/gpjax/config.py @@ -1,21 +1,9 @@ -import distrax as dx -import jax.numpy as jnp import jax.random as jr -import tensorflow_probability.substrates.jax.bijectors as tfb +import numpyro.distributions as npd from ml_collections import ConfigDict __config = None -Softplus = dx.Lambda( - forward=lambda x: jnp.log(1 + jnp.exp(x)), - inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), -) - -# TODO: Remove this once 'FillTriangular' is added to Distrax. -FillTriangular = dx.Chain([tfb.FillTriangular()]) - -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) - def get_defaults() -> ConfigDict: """Construct and globally register the config file used within GPJax. @@ -31,9 +19,9 @@ def get_defaults() -> ConfigDict: # Default bijections config.transformations = transformations = ConfigDict() - transformations.positive_transform = Softplus - transformations.identity_transform = Identity - transformations.triangular_transform = FillTriangular + transformations.positive_transform = npd.transforms.SoftplusTransform() + transformations.identity_transform = npd.transforms.IdentityTransform() + transformations.triangular_transform = npd.transforms.CholeskyTransform() # Default parameter transforms transformations.lengthscale = "positive_transform" @@ -58,7 +46,7 @@ def get_defaults() -> ConfigDict: return __config -def add_parameter(param_name: str, bijection: dx.Bijector) -> None: +def add_parameter(param_name: str, bijection: npd.transforms.Transform) -> None: """Add a parameter and its corresponding transform to GPJax's config file. Args: diff --git a/gpjax/gps.py b/gpjax/gps.py index 1bf3fcae..1e3642af 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -16,10 +16,10 @@ from abc import abstractmethod from typing import Any, Callable, Dict, Optional -import distrax as dx import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp +import numpyro.distributions as npd from chex import dataclass from jaxtyping import Array, Float @@ -39,7 +39,7 @@ class AbstractGP: """Abstract Gaussian process object.""" - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Evaluate the Gaussian process at the given points. Args: @@ -47,12 +47,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: **kwargs (Any): The keyword arguments to pass to the GP's `predict` method. 
Returns: - dx.Distribution: A multivariate normal random variable representation of the Gaussian process. + npd.Distribution: A multivariate normal random variable representation of the Gaussian process. """ return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Compute the latent function's multivariate normal distribution. Args: @@ -110,7 +110,7 @@ def __rmul__(self, other: AbstractLikelihood): """ return self.__mul__(other) - def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the GP's prior mean and variance. Args: @@ -120,16 +120,13 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Callable[[Float[Array, "N D"]], dx.Distribution]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. """ - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs n_test = t.shape[0] μt = self.mean_function(t, params["mean_function"]) Ktt = gram(self.kernel, t, params["kernel"]) Ktt += I(n_test) * self.jitter - - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(μt.squeeze()), Ktt - ) + return npd.MultivariateNormal(jnp.atleast_1d(μt.squeeze()), Ktt) return predict_fn @@ -161,7 +158,7 @@ class AbstractPosterior(AbstractGP): jitter: Optional[float] = DEFAULT_JITTER @abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Predict the GP's output given the input. Args: @@ -199,7 +196,7 @@ class ConjugatePosterior(AbstractPosterior): def predict( self, train_data: Dataset, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Args: @@ -207,7 +204,7 @@ def predict( params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `numpyro.distributions.MultivariateNormal`. 
""" x, y, n = train_data.X, train_data.y, train_data.n @@ -226,7 +223,7 @@ def predict( # w = L⁻¹ (y - μx) w = jsp.linalg.solve_triangular(L, y - μx, lower=True) - def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs n_test = t.shape[0] μt = self.prior.mean_function(t, params["mean_function"]) @@ -243,9 +240,7 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: covariance = Ktt - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) covariance += I(n_test) * self.jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict @@ -281,9 +276,8 @@ def mll( L = jnp.linalg.cholesky(Sigma) # p(y | x, θ), where θ are the model hyperparameters: - marginal_likelihood = dx.MultivariateNormalTri( - jnp.atleast_1d(μx.squeeze()), L - ) + + marginal_likelihood = npd.MultivariateNormal(jnp.atleast_1d(μx.squeeze()), scale_tril=L) # log p(θ) log_prior_density = evaluate_priors(params, priors) @@ -317,15 +311,16 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: def predict( self, train_data: Dataset, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function. - Args: - train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (Dict): A dictionary of parameters that should be used to compute the posterior. + Args: + train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. + params (Dict): A dictionary of parameters that should be used to compute the posterior. - Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Returns: + <<<<<<< HEAD + tp.Callable[[Array], npd.Distribution]: A function that accepts an input array and returns the predictive distribution as a `numpyro.distributions.MultivariateNormal`. 
""" x, n = train_data.X, train_data.n @@ -333,7 +328,7 @@ def predict( Kxx += I(n) * self.jitter Lx = jnp.linalg.cholesky(Kxx) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs n_test = t.shape[0] Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"]) @@ -350,9 +345,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) covariance += I(n_test) * self.jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -376,7 +369,7 @@ def marginal_log_likelihood( if not priors: priors = copy_dict_structure(self._initialise_params(jr.PRNGKey(0))) - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) + priors["latent"] = npd.Normal(loc=0.0, scale=1.0) def mll(params: Dict): Kxx = gram(self.prior.kernel, x, params["kernel"]) @@ -399,18 +392,14 @@ def mll(params: Dict): return mll -def construct_posterior( - prior: Prior, likelihood: AbstractLikelihood -) -> AbstractPosterior: +def construct_posterior(prior: Prior, likelihood: AbstractLikelihood) -> AbstractPosterior: if isinstance(likelihood, Conjugate): PosteriorGP = ConjugatePosterior elif isinstance(likelihood, NonConjugate): PosteriorGP = NonConjugatePosterior else: - raise NotImplementedError( - f"No posterior implemented for {likelihood.name} likelihood" - ) + raise NotImplementedError(f"No posterior implemented for {likelihood.name} likelihood") return PosteriorGP(prior=prior, likelihood=likelihood) diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index a2bec329..e9bf3173 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -16,9 +16,9 @@ import abc from typing import Any, Callable, Dict, Optional -import distrax as dx import jax.numpy as jnp import jax.scipy as jsp +import numpyro.distributions as npd from chex import dataclass from jaxtyping import Array, Float @@ -33,7 +33,7 @@ class AbstractLikelihood: num_datapoints: int # The number of datapoints that the likelihood factorises over. name: Optional[str] = "Likelihood" - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Evaluate the likelihood function at a given predictive distribution. Args: @@ -46,7 +46,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: return self.predict(*args, **kwargs) @abc.abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Evaluate the likelihood function at a given predictive distribution. Args: @@ -116,24 +116,24 @@ def link_function(self) -> Callable: Callable: A link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: Dict) -> dx.Distribution: - return dx.Normal(loc=x, scale=params["obs_noise"]) + def link_fn(x, params: Dict) -> npd.Distribution: + return npd.Normal(loc=x, scale=params["obs_noise"]) return link_fn - def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: + def predict(self, dist: npd.Distribution, params: Dict) -> npd.Distribution: """Evaluate the Gaussian likelihood function at a given predictive distribution. 
Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix. Args: - dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. + dist (numpyro.distributions.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. params (Dict): The parameters of the likelihood function. Returns: - dx.Distribution: The predictive distribution. + numpyro.distributions.Distribution: The predictive distribution. """ n_data = dist.event_shape[0] - noisy_cov = dist.covariance() + I(n_data) * params["likelihood"]["obs_noise"] - return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) + noisy_cov = dist.covariance_matrix + I(n_data) * params["likelihood"]["obs_noise"] + return npd.MultivariateNormal(dist.mean, noisy_cov) @dataclass @@ -159,8 +159,8 @@ def link_function(self) -> Callable: Callable: A probit link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: Dict) -> dx.Distribution: - return dx.Bernoulli(probs=inv_probit(x)) + def link_fn(x, params: Dict) -> npd.Distribution: + return npd.Bernoulli(probs=inv_probit(x)) return link_fn @@ -172,26 +172,24 @@ def predictive_moment_fn(self) -> Callable: Callable: A callable object that accepts a mean and variance term from which the predictive random variable is computed. """ - def moment_fn( - mean: Float[Array, "N D"], variance: Float[Array, "N D"], params: Dict - ): + def moment_fn(mean: Float[Array, "N D"], variance: Float[Array, "N D"], params: Dict): rv = self.link_function(mean / jnp.sqrt(1 + variance), params) return rv return moment_fn - def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: + def predict(self, dist: npd.Distribution, params: Dict) -> npd.Distribution: """Evaluate the pointwise predictive distribution, given a Gaussian process posterior and likelihood parameters. Args: - dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. + dist (npd.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. params (Dict): The parameters of the likelihood function. Returns: - dx.Distribution: The pointwise predictive distribution. + npd.Distribution: The pointwise predictive distribution. """ - variance = jnp.diag(dist.covariance()) - mean = dist.mean() + variance = jnp.diag(dist.covariance_matrix) + mean = dist.mean return self.predictive_moment_fn(mean.ravel(), variance, params) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 1360fe85..6a8d8c33 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -18,10 +18,10 @@ from typing import Dict, Tuple from warnings import warn -import distrax as dx import jax import jax.numpy as jnp import jax.random as jr +import numpyro.distributions as npd from chex import dataclass from jaxtyping import Array, Float @@ -29,8 +29,6 @@ from .types import PRNGKeyType from .utils import merge_dictionaries -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) - ################################ # Base operations @@ -39,7 +37,6 @@ class ParameterState: """The state of the model. 
This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained.""" - params: Dict trainables: Dict bijectors: Dict @@ -161,7 +158,7 @@ def recursive_bijectors(ps, bs) -> Tuple[Dict, Dict]: transform_type = transform_set[key] bijector = transform_set[transform_type] else: - bijector = Identity + bijector = npd.transforms.IdentityTransform() warnings.warn( f"Parameter {key} has no transform. Defaulting to identity transfom." ) @@ -181,8 +178,7 @@ def constrain(params: Dict, bijectors: Dict) -> Dict: Returns: Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - - map = lambda param, trans: trans.forward(param) + map = lambda param, trans: trans.inv(param) return jax.tree_util.tree_map(map, params, bijectors) @@ -198,7 +194,7 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - map = lambda param, trans: trans.inverse(param) + map = lambda param, trans: trans.__call__(param) return jax.tree_util.tree_map(map, params, bijectors) @@ -207,7 +203,7 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: # Priors ################################ def log_density( - param: Float[Array, "D"], density: dx.Distribution + param: Float[Array, "D"], density: npd.Distribution ) -> Float[Array, "1"]: """Compute the log density of a parameter given a distribution. @@ -289,17 +285,17 @@ def prior_checks(priors: Dict) -> Dict: if "latent" in priors.keys(): latent_prior = priors["latent"] if latent_prior is not None: - if latent_prior.name != "Normal": + if not isinstance(latent_prior, npd.Normal): warnings.warn( - f"A {latent_prior.name} distribution prior has been placed on" + f"A {type(latent_prior)} distribution prior has been placed on" " the latent function. It is strongly advised that a" " unit Gaussian prior is used." ) else: warnings.warn("Placing unit Gaussian prior on latent function.") - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) + priors["latent"] = npd.Normal(loc=0.0, scale=1.0) else: - priors["latent"] = dx.Normal(loc=0.0, scale=1.0) + priors["latent"] = npd.Normal(loc=0.0, scale=1.0) return priors diff --git a/gpjax/quadrature.py b/gpjax/quadrature.py index 8312f6e5..839c6b9d 100644 --- a/gpjax/quadrature.py +++ b/gpjax/quadrature.py @@ -26,7 +26,7 @@ def gauss_hermite_quadrature( fun: Callable, mean: Float[Array, "N D"], - var: Float[Array, "N D"], + sd: Float[Array, "N D"], deg: Optional[int] = DEFAULT_NUM_GAUSS_HERMITE_POINTS, *args, **kwargs @@ -36,15 +36,14 @@ def gauss_hermite_quadrature( Args: fun (Callable): The function for which quadrature should be applied to. mean (Float[Array, "N D"]): The mean of the Gaussian distribution that is used to shift quadrature points. - var (Float[Array, "N D"]): The variance of the Gaussian distribution that is used to scale quadrature points. + sd (Float[Array, "N D"]): The standard deviation of the Gaussian distribution that is used to scale quadrature points. deg (int, optional): The number of quadrature points that are to be used. Defaults to 20. Returns: Float[Array, "N"]: The evaluated integrals value. 
""" gh_points, gh_weights = np.polynomial.hermite.hermgauss(deg) - stdev = jnp.sqrt(var) - X = mean + jnp.sqrt(2.0) * stdev * gh_points + X = mean + jnp.sqrt(2.0) * sd * gh_points W = gh_weights / jnp.sqrt(jnp.pi) return jnp.sum(fun(X, *args, **kwargs) * W, axis=1) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 0a939fd8..96d9bfa9 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -16,9 +16,9 @@ import abc from typing import Any, Callable, Dict, Optional -import distrax as dx import jax.numpy as jnp import jax.scipy as jsp +import numpyro.distributions as npd from chex import dataclass from jaxtyping import Array, Float @@ -36,7 +36,7 @@ class AbstractVariationalFamily: """Abstract base class used to represent families of distributions that can be used within variational inference.""" - def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> npd.Distribution: """For a given set of parameters, compute the latent function's prediction under the variational approximation. Args: @@ -61,7 +61,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: raise NotImplementedError @abc.abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> npd.Distribution: """Predict the GP's output given the input. Args: @@ -142,12 +142,13 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = npd.MultivariateNormal(jnp.atleast_1d(mu.squeeze()), scale_tril=sqrt) + pu = npd.MultivariateNormal(jnp.atleast_1d(μz.squeeze()), scale_tril=Lz) + return kld_dense_dense(qu, pu) - return qu.kl_divergence(pu) - - def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict( + self, params: Dict + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as: @@ -158,7 +159,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], npd.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -170,7 +171,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs n_test = t.shape[0] Ktt = gram(self.prior.kernel, t, params["kernel"]) @@ -197,9 +198,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: ) covariance += I(n_test) * self.jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -228,14 +227,13 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] - m = self.num_inducing - - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalDiag(jnp.zeros(m)) - return qu.kl_divergence(pu) + qu = npd.MultivariateNormal(jnp.atleast_1d(mu.squeeze()), scale_tril=sqrt) + return kld_dense_white(qu) - def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict( + self, params: Dict + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -246,7 +244,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], npd.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -257,7 +255,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs n_test = t.shape[0] Ktt = gram(self.prior.kernel, t, params["kernel"]) @@ -281,9 +279,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: ) covariance += I(n_test) * self.jitter - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -360,12 +356,14 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = npd.MultivariateNormal(jnp.atleast_1d(mu.squeeze()), scale_tril=sqrt) + pu = npd.MultivariateNormal(jnp.atleast_1d(μz.squeeze()), scale_tril=Lz) - return qu.kl_divergence(pu) + return kld_dense_dense(qu, pu) - def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict( + self, params: Dict + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -378,7 +376,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], npd.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] @@ -408,7 +406,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -433,9 +431,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) ) - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -510,12 +506,14 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) - pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + qu = npd.MultivariateNormal(jnp.atleast_1d(mu.squeeze()), scale_tril=sqrt) + pu = npd.MultivariateNormal(jnp.atleast_1d(μz.squeeze()), scale_tril=Lz) - return qu.kl_divergence(pu) + return kld_dense_dense(qu, pu) - def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict( + self, params: Dict + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -528,7 +526,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], npd.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ expectation_vector = params["variational_family"]["moments"][ "expectation_vector" @@ -554,7 +552,7 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -579,9 +577,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) ) - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn @@ -619,12 +615,12 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: def predict( self, train_data: Dataset, params: Dict - ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + ) -> Callable[[Float[Array, "N D"]], npd.Distribution]: """Compute the predictive distribution of the GP at the test inputs. 
Args: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], npd.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ x, y = train_data.X, train_data.y @@ -662,7 +658,7 @@ def predict( Lz.T, Lz_inv_Kzx_diff, lower=False ) - def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> npd.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -684,13 +680,86 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: + jnp.matmul(L_inv_Lz_inv_Kzt.T, L_inv_Lz_inv_Kzt) ) - return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance - ) + return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), covariance) return predict_fn +# TODO: Abstract these out to a KL divergence that accepts a linear operator to facilitate structured covariances other than dense. +def kld_dense_dense( + q: npd.MultivariateNormal, p: npd.MultivariateNormal +) -> Float[Array, "1"]: + """Kullback-Leibler divergence KL[q(x)||p(x)] between two dense covariance Gaussian distributions + q(x) = N(x; μq, Σq) and p(x) = N(x; μp, Σp). + + Args: + q (npd.MultivariateNormal): A multivariate Gaussian distribution. + p (npd.MultivariateNormal): A multivariate Gaussian distribution. + + Returns: + Float[Array, "1"]: The KL divergence between the two distributions. + """ + + q_mu = q.loc + q_sqrt = q.scale_tril + n = q_mu.shape[-1] + + p_mu = p.loc + p_sqrt = p.scale_tril + + diag = jnp.diag(q_sqrt) + + # Trace term tr(Σp⁻¹ Σq) + trace = jnp.sum(jnp.square(jsp.linalg.solve_triangular(p_sqrt, q_sqrt, lower=True))) + + # Mahalanobis term: μqᵀ Σp⁻¹ μq + alpha = jsp.linalg.solve_triangular(p_sqrt, p_mu - q_mu, lower=True) + mahalanobis = jnp.sum(jnp.square(alpha)) + + # log|Σq| + logdet_qcov = jnp.sum(jnp.log(jnp.square(diag))) + two_kl = mahalanobis - n - logdet_qcov + trace + + # log|Σp| + log_det_pcov = jnp.sum(jnp.log(jnp.square(jnp.diag(p_sqrt)))) + two_kl += log_det_pcov + + return two_kl / 2.0 + + +def kld_dense_white(q: npd.MultivariateNormal) -> Float[Array, "1"]: + """Kullback-Leibler divergence KL[q(x)||p(x)] between a dense covariance Gaussian distribution + q(x) = N(x; μq, Σq), and white identity Gaussian p(x) = N(x; 0, I). + + This is useful for variational inference with a whitened variational family. + + Args: + q (npd.MultivariateNormal): A multivariate Gaussian distribution. + + Returns: + Float[Array, "1"]: The KL divergence between the two distributions. + """ + + q_mu = q.loc + q_sqrt = q.scale_tril + n = q_mu.shape[-1] + + diag = jnp.diag(q_sqrt) + + # Trace term tr(Σp⁻¹ Σq), and alpha for Mahalanobis term: + alpha = q_mu + trace = jnp.sum(jnp.square(q_sqrt)) + + # Mahalanobis term: μqᵀ Σp⁻¹ μq + mahalanobis = jnp.sum(jnp.square(alpha)) + + # log|Σq| (no log|Σp| as this is just zero!)
+ logdet_qcov = jnp.sum(jnp.log(jnp.square(diag))) + two_kl = mahalanobis - n - logdet_qcov + trace + + return two_kl / 2.0 + + __all__ = [ "AbstractVariationalFamily", "AbstractVariationalGaussian", diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 16e1e604..7d701fd0 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -117,21 +117,21 @@ def variational_expectation( Array: The expectation of the model's log-likelihood under our variational distribution. """ x, y = batch.X, batch.y + link_fn = self.likelihood.link_function - # q(f(x)) - predictive_dist = vmap(self.variational_family.predict(params))(x[:, None]) - mean = predictive_dist.mean().val.reshape(-1, 1) - variance = predictive_dist.variance().val.reshape(-1, 1) + # variational distribution q(f(.)) = N(f(.); μ(.), Σ(., .)) + q = self.variational_family(params) + + # μ(x) and √diag(Σ(x, x)) + (mean, sd), _ = vmap(lambda x_i: q(x_i).tree_flatten())(x[:, None]) # log(p(y|f(x))) - log_prob = vmap( - lambda f, y: self.likelihood.link_function( - f, params["likelihood"] - ).log_prob(y) - ) + log_prob = vmap(lambda f, y: link_fn(f, params["likelihood"]).log_prob(y)) # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) - expectation = gauss_hermite_quadrature(log_prob, mean, variance, y=y) + expectation = gauss_hermite_quadrature( + log_prob, mean.reshape(-1, 1), sd.reshape(-1, 1), y=y + ) return expectation diff --git a/paper.bib b/paper/paper.bib similarity index 100% rename from paper.bib rename to paper/paper.bib diff --git a/paper.md b/paper/paper.md similarity index 100% rename from paper.md rename to paper/paper.md diff --git a/paper/paper.pdf b/paper/paper.pdf new file mode 100644 index 00000000..bd501dd0 Binary files /dev/null and b/paper/paper.pdf differ diff --git a/setup.py b/setup.py index 79fa846c..080c3e41 100644 --- a/setup.py +++ b/setup.py @@ -16,8 +16,7 @@ def parse_requirements_file(filename): "jaxlib>=0.1.47", "optax", "chex", - "distrax>=0.1.2", - "tensorflow-probability>=0.16.0", + "numpyro", "tqdm>=4.0.0", "ml-collections==0.1.0", "jaxtyping>=0.0.2", diff --git a/tests/test_config.py b/tests/test_config.py index b2198a1c..2dcf5d6d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,26 +1,25 @@ from ml_collections import ConfigDict -from tensorflow_probability.substrates.jax import bijectors as tfb - +import numpyro.distributions as npd from gpjax.config import add_parameter, get_defaults def test_add_parameter(): - add_parameter("test_parameter", tfb.Identity()) + add_parameter("test_parameter", npd.transforms.IdentityTransform()) config = get_defaults() assert "test_parameter" in config.transformations assert "test_parameter_transform" in config.transformations assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], tfb.Bijector) + assert isinstance(config.transformations["test_parameter_transform"], npd.transforms.Transform) def test_add_parameter(): config = get_defaults() - add_parameter("test_parameter", tfb.Identity()) + add_parameter("test_parameter", npd.transforms.IdentityTransform()) config = get_defaults() assert "test_parameter" in config.transformations assert "test_parameter_transform" in config.transformations assert config.transformations["test_parameter"] == "test_parameter_transform" - assert isinstance(config.transformations["test_parameter_transform"], tfb.Bijector) + assert isinstance(config.transformations["test_parameter_transform"], 
npd.transforms.Transform) def test_get_defaults(): diff --git a/tests/test_gp.py b/tests/test_gp.py index 5a940f53..c99242e8 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -1,9 +1,9 @@ import typing as tp -import distrax as dx import jax.numpy as jnp import jax.random as jr import pytest +import numpyro.distributions as npd from gpjax import Dataset, initialise from gpjax.gps import ( @@ -32,9 +32,9 @@ def test_prior(num_datapoints): x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) predictive_dist = prior_rv_fn(x) - assert isinstance(predictive_dist, dx.Distribution) - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() + assert isinstance(predictive_dist, npd.Distribution) + mu = predictive_dist.mean + sigma = predictive_dist.covariance_matrix assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) @@ -75,10 +75,10 @@ def test_conjugate_posterior(num_datapoints): x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) predictive_dist = predictive_dist_fn(x) - assert isinstance(predictive_dist, dx.Distribution) + assert isinstance(predictive_dist, npd.Distribution) - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() + mu = predictive_dist.mean + sigma = predictive_dist.covariance_matrix assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) @@ -117,10 +117,10 @@ def test_nonconjugate_posterior(num_datapoints, likel): x = jnp.linspace(-3.0, 3.0, num_datapoints).reshape(-1, 1) predictive_dist = predictive_dist_fn(x) - assert isinstance(predictive_dist, dx.Distribution) + assert isinstance(predictive_dist, npd.Distribution) - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() + mu = predictive_dist.mean + sigma = predictive_dist.covariance_matrix assert mu.shape == (num_datapoints,) assert sigma.shape == (num_datapoints, num_datapoints) diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 57f0ea09..e9b48f32 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -1,8 +1,8 @@ import typing as tp -import distrax as dx import jax.numpy as jnp import jax.random as jr +import numpyro.distributions as npd import pytest from gpjax.likelihoods import ( @@ -39,8 +39,8 @@ def test_predictive_moment(n): pred_mom_fn = lhood.predictive_moment_fn params, _, _ = initialise(lhood, key).unpack() rv = pred_mom_fn(fmean, fvar, params) - mu = rv.mean() - sigma = rv.variance() + mu = rv.mean + sigma = rv.variance assert isinstance(lhood.predictive_moment_fn, tp.Callable) assert mu.shape == (n,) assert sigma.shape == (n,) @@ -57,7 +57,7 @@ def test_link_fns(lik: AbstractLikelihood, n: int): x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) l_eval = link_fn(x, params) - assert isinstance(l_eval, dx.Distribution) + assert isinstance(l_eval, npd.Distribution) @pytest.mark.parametrize("noise", [0.1, 0.5, 1.0]) @@ -65,33 +65,33 @@ def test_call_gaussian(noise): key = jr.PRNGKey(123) n = 10 lhood = Gaussian(num_datapoints=n) - dist = dx.MultivariateNormalFullCovariance(jnp.zeros(n), jnp.eye(n)) + dist = npd.MultivariateNormal(jnp.zeros(n), covariance_matrix=jnp.eye(n)) params = {"likelihood": {"obs_noise": noise}} l_dist = lhood(dist, params) - assert (l_dist.mean() == jnp.zeros(n)).all() + assert (l_dist.mean == jnp.zeros(n)).all() noise_mat = jnp.diag(jnp.repeat(noise, n)) - assert (l_dist.covariance() == jnp.eye(n) + noise_mat).all() + assert (l_dist.covariance_matrix == jnp.eye(n) + noise_mat).all() l_dist = lhood.predict(dist, params) - assert 
(l_dist.mean() == jnp.zeros(n)).all() + assert (l_dist.mean == jnp.zeros(n)).all() noise_mat = jnp.diag(jnp.repeat(noise, n)) - assert (l_dist.covariance() == jnp.eye(n) + noise_mat).all() + assert (l_dist.covariance_matrix == jnp.eye(n) + noise_mat).all() def test_call_bernoulli(): n = 10 lhood = Bernoulli(num_datapoints=n) - dist = dx.MultivariateNormalFullCovariance(jnp.zeros(n), jnp.eye(n)) + dist = npd.MultivariateNormal(jnp.zeros(n), covariance_matrix=jnp.eye(n)) params = {"likelihood": {}} l_dist = lhood(dist, params) - assert (l_dist.mean() == 0.5 * jnp.ones(n)).all() - assert (l_dist.variance() == 0.25 * jnp.ones(n)).all() + assert (l_dist.mean == 0.5 * jnp.ones(n)).all() + assert (l_dist.variance == 0.25 * jnp.ones(n)).all() l_dist = lhood.predict(dist, params) - assert (l_dist.mean() == 0.5 * jnp.ones(n)).all() - assert (l_dist.variance() == 0.25 * jnp.ones(n)).all() + assert (l_dist.mean == 0.5 * jnp.ones(n)).all() + assert (l_dist.variance == 0.25 * jnp.ones(n)).all() @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index f384d046..94c377bd 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,10 +1,9 @@ import typing as tp -import distrax as dx import jax.numpy as jnp +import numpyro.distributions as npd import jax.random as jr import pytest -from tensorflow_probability.substrates.jax import distributions as tfd from gpjax.gps import Prior from gpjax.kernels import RBF @@ -57,7 +56,7 @@ def test_non_conjugate_initialise(): @pytest.mark.parametrize("x", [-1.0, 0.0, 1.0]) def test_lpd(x): val = jnp.array(x) - dist = tfd.Normal(loc=0.0, scale=1.0) + dist = npd.Normal(loc=0.0, scale=1.0) lpd = log_density(val, dist) assert lpd is not None assert log_density(val, None) == 0.0 @@ -81,7 +80,7 @@ def test_recursive_complete(lik): posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() priors = {"kernel": {}} - priors["kernel"]["lengthscale"] = tfd.HalfNormal(scale=2.0) + priors["kernel"]["lengthscale"] = npd.HalfNormal(scale=2.0) container = copy_dict_structure(params) complete_priors = recursive_complete(container, priors) for ( @@ -90,7 +89,7 @@ def test_recursive_complete(lik): v2, ) in recursive_items(params, complete_priors): if k == "lengthscale": - assert isinstance(v2, tfd.HalfNormal) + assert isinstance(v2, npd.HalfNormal) else: assert v2 == None @@ -109,10 +108,10 @@ def test_prior_evaluation(): } priors = { "kernel": { - "lengthscale": tfd.Gamma(1.0, 1.0), - "variance": tfd.Gamma(2.0, 2.0), + "lengthscale": npd.Gamma(1.0, 1.0), + "variance": npd.Gamma(2.0, 2.0), }, - "likelihood": {"obs_noise": tfd.Gamma(3.0, 3.0)}, + "likelihood": {"obs_noise": npd.Gamma(3.0, 3.0)}, } lpd = evaluate_priors(params, priors) assert pytest.approx(lpd) == -2.0110168 @@ -147,8 +146,8 @@ def test_incomplete_priors(): } priors = { "kernel": { - "lengthscale": tfd.Gamma(1.0, 1.0), - "variance": tfd.Gamma(2.0, 2.0), + "lengthscale": npd.Gamma(1.0, 1.0), + "variance": npd.Gamma(2.0, 2.0), }, } container = copy_dict_structure(params) @@ -164,7 +163,7 @@ def test_checks(num_datapoints): priors = prior_checks(incomplete_priors) assert "latent" in priors.keys() assert "variance" not in priors.keys() - assert isinstance(priors["latent"], dx.Normal) + assert isinstance(priors["latent"], npd.Normal) def test_structure_priors(): @@ -172,8 +171,8 @@ def test_structure_priors(): params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() priors = { 
"kernel": { - "lengthscale": tfd.Gamma(1.0, 1.0), - "variance": tfd.Gamma(2.0, 2.0), + "lengthscale": npd.Gamma(1.0, 1.0), + "variance": npd.Gamma(2.0, 2.0), }, } structured_priors = structure_priors(params, priors) @@ -189,7 +188,7 @@ def recursive_fn(d1, d2, fn: tp.Callable[[tp.Any], tp.Any]): assert v -@pytest.mark.parametrize("latent_prior", [dx.Laplace(0.0, 1.0), tfd.Laplace(0.0, 1.0)]) +@pytest.mark.parametrize("latent_prior", [npd.Laplace(0.0, 1.0), npd.Laplace(0.0, 1.0)]) def test_prior_checks(latent_prior): priors = { "kernel": {"lengthscale": None, "variance": None}, @@ -199,7 +198,7 @@ def test_prior_checks(latent_prior): } new_priors = prior_checks(priors) assert "latent" in new_priors.keys() - assert new_priors["latent"].name == "Normal" + assert isinstance(new_priors["latent"], npd.Normal) priors = { "kernel": {"lengthscale": None, "variance": None}, @@ -208,7 +207,7 @@ def test_prior_checks(latent_prior): } new_priors = prior_checks(priors) assert "latent" in new_priors.keys() - assert new_priors["latent"].name == "Normal" + assert isinstance(new_priors["latent"], npd.Normal) priors = { "kernel": {"lengthscale": None, "variance": None}, @@ -219,7 +218,7 @@ def test_prior_checks(latent_prior): with pytest.warns(UserWarning): new_priors = prior_checks(priors) assert "latent" in new_priors.keys() - assert new_priors["latent"].name == "Laplace" + assert isinstance(new_priors["latent"], npd.Laplace) ######################### @@ -233,8 +232,8 @@ def test_output(num_datapoints, likelihood): assert isinstance(bijectors, dict) for k, v1, v2 in recursive_items(bijectors, bijectors): - assert isinstance(v1.forward, tp.Callable) - assert isinstance(v2.inverse, tp.Callable) + assert isinstance(v1.__call__, tp.Callable) + assert isinstance(v2.inv, tp.Callable) unconstrained_params = unconstrain(params, bijectors) assert ( @@ -252,5 +251,5 @@ def test_output(num_datapoints, likelihood): a_bijectors = build_bijectors(augmented_params) assert "test_param" in list(a_bijectors.keys()) - assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 - assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 + assert a_bijectors["test_param"](jnp.array([1.0])) == 1.0 + assert a_bijectors["test_param"].inv(jnp.array([1.0])) == 1.0 diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index e6f86827..5def2d16 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -1,7 +1,7 @@ import typing as tp from mimetypes import init -import distrax as dx +import numpyro.distributions as npd import jax.numpy as jnp import jax.random as jr import pytest @@ -128,10 +128,10 @@ def test_variational_gaussians( assert isinstance(predictive_dist_fn, tp.Callable) predictive_dist = predictive_dist_fn(test_inputs) - assert isinstance(predictive_dist, dx.Distribution) + assert isinstance(predictive_dist, npd.Distribution) - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() + mu = predictive_dist.mean + sigma = predictive_dist.covariance_matrix assert isinstance(mu, jnp.ndarray) assert isinstance(sigma, jnp.ndarray) @@ -194,10 +194,10 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ assert isinstance(predictive_dist_fn, tp.Callable) predictive_dist = predictive_dist_fn(test_inputs) - assert isinstance(predictive_dist, dx.Distribution) + assert isinstance(predictive_dist, npd.Distribution) - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() + mu = predictive_dist.mean + sigma 
= predictive_dist.covariance_matrix assert isinstance(mu, jnp.ndarray) assert isinstance(sigma, jnp.ndarray)
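Reviewer note on the new KL helpers: as a quick sanity check, kld_dense_white(q) should coincide with kld_dense_dense(q, p) when p is a standard normal, and both should match the closed form 0.5 * (tr(Σq) + μqᵀμq - n - log|Σq|). The sketch below is illustrative only and not part of the patch; it assumes kld_dense_dense and kld_dense_white can be imported from gpjax.variational_families (they are defined there in this diff, but whether they are re-exported at package level is not shown).

import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as npd

# Assumed import path; the helpers live in gpjax/variational_families.py in this patch.
from gpjax.variational_families import kld_dense_dense, kld_dense_white

n = 5
mu = jr.normal(jr.PRNGKey(0), (n,))
# A well-conditioned lower-triangular factor Lq, so that Σq = Lq Lqᵀ.
scale = jnp.tril(jr.normal(jr.PRNGKey(1), (n, n))) + n * jnp.eye(n)

q = npd.MultivariateNormal(mu, scale_tril=scale)
p_white = npd.MultivariateNormal(jnp.zeros(n), scale_tril=jnp.eye(n))

# The whitened KL agrees with the dense-dense KL taken against N(0, I).
assert jnp.allclose(kld_dense_white(q), kld_dense_dense(q, p_white), atol=1e-5)

# Both agree with the closed form 0.5 * (tr(Σq) + μqᵀμq - n - log|Σq|).
two_kl = (
    jnp.trace(q.covariance_matrix)
    + mu @ mu
    - n
    - 2.0 * jnp.sum(jnp.log(jnp.diag(scale)))
)
assert jnp.allclose(kld_dense_white(q), 0.5 * two_kl, atol=1e-5)

Both helpers operate on the scale_tril factors via triangular solves, so no covariance matrix is inverted or materialised explicitly.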