Natgrads #90

Merged · 67 commits · Sep 20, 2022

Commits
5c33301
Initial commit.
daniel-dodd Aug 31, 2022
364b34e
Constrain, Unconstrain + Tests
daniel-dodd Aug 31, 2022
d1f4d0c
Parameter state (See comment)
daniel-dodd Aug 31, 2022
7eca9e7
Update nbs.
daniel-dodd Aug 31, 2022
e1d43ee
Add functionality to transform ParameterState
thomaspinder Sep 1, 2022
da52be4
Undo change
thomaspinder Sep 1, 2022
979546e
WIP for constrainers on state
thomaspinder Sep 1, 2022
8fcb114
Revert "WIP for constrainers on state"
daniel-dodd Sep 2, 2022
1fb92ee
Test MCMC docs.
daniel-dodd Sep 4, 2022
fe1cf90
Fix MCMC?
daniel-dodd Sep 7, 2022
5a01e63
Update classification.ipynb
daniel-dodd Sep 7, 2022
ff39d46
Update variational_inference.py
daniel-dodd Sep 16, 2022
ce79de4
Update classification.ipynb
daniel-dodd Sep 16, 2022
97f9d9f
Initial commit.
daniel-dodd Aug 31, 2022
2b617da
Constrain, Unconstrain + Tests
daniel-dodd Aug 31, 2022
6a17542
Parameter state (See comment)
daniel-dodd Aug 31, 2022
faca79a
Update nbs.
daniel-dodd Aug 31, 2022
dea684a
WIP for constrainers on state
thomaspinder Sep 1, 2022
d042a7b
Revert "WIP for constrainers on state"
daniel-dodd Sep 2, 2022
feb95c9
Test MCMC docs.
daniel-dodd Sep 4, 2022
d5e989b
Update classification.ipynb
daniel-dodd Sep 5, 2022
c358fc1
Update classification.ipynb
daniel-dodd Sep 7, 2022
3f5465b
Fix MCMC?
daniel-dodd Sep 7, 2022
8dbfc44
Update classification.ipynb
daniel-dodd Sep 7, 2022
dda0c6d
Natural gradients.
daniel-dodd Jun 1, 2022
dbaf04b
Update tests + fix minor bugs.
daniel-dodd Jun 1, 2022
048b0cc
Update variational_families.py
daniel-dodd Jun 1, 2022
40f9dd2
Update variational_families.py
daniel-dodd Jun 1, 2022
a0b5ab2
Update variational_families.py
daniel-dodd Jun 1, 2022
426e6cb
Update variational_families.py
daniel-dodd Jun 8, 2022
0604cdd
Update variational_families.py
daniel-dodd Jun 8, 2022
700acfb
Update test_variational_families.py
daniel-dodd Jun 8, 2022
a96281f
Natural gradients sketch.
daniel-dodd Jun 8, 2022
35173f8
Add notion of "moments"
daniel-dodd Jun 15, 2022
218ebf4
Add AbstractVariationalGaussian class.
daniel-dodd Jun 15, 2022
642cc73
Update test_variational_families.py
daniel-dodd Jun 15, 2022
347be8b
Update natural_gradients.py
daniel-dodd Jun 15, 2022
37bedd6
Update natural_gradients.py
daniel-dodd Jun 15, 2022
fa6ca2b
Update.
daniel-dodd Jun 15, 2022
02f068b
Minimal working natural gradient functions for NAT PARAMETERISATION.
daniel-dodd Jun 16, 2022
612f2c7
Add tests.
daniel-dodd Jun 16, 2022
25a7893
Update natural_gradients.py
daniel-dodd Jun 16, 2022
a5499d0
Update natural_gradients.py
daniel-dodd Jun 17, 2022
820c752
Add rough notebook.
daniel-dodd Jun 17, 2022
8bf35a8
Minimal working example complete.
daniel-dodd Jun 26, 2022
610338c
Update natgrads.ipynb
daniel-dodd Jun 26, 2022
c9ee548
Rebase master.
daniel-dodd Jul 20, 2022
6fe4439
Add development notebook for general
daniel-dodd Jul 21, 2022
20473e6
Update training loop.
daniel-dodd Aug 19, 2022
7f7c424
Update training loop. Add collapsed bound and natural gradient relati…
daniel-dodd Aug 19, 2022
48fcfa8
Update test_natural_gradients.py
daniel-dodd Aug 19, 2022
1e8a8e0
Fix variational families.
daniel-dodd Aug 23, 2022
c265b9f
Update training loop and notebook.
daniel-dodd Aug 23, 2022
5a9f232
Clean variational families.
daniel-dodd Aug 23, 2022
c01f733
Update typing.
daniel-dodd Aug 23, 2022
1cf8514
Address review comments.
daniel-dodd Aug 23, 2022
4833648
Address review (except notebook)..
daniel-dodd Aug 24, 2022
f45d99b
Address comments.
daniel-dodd Aug 24, 2022
8f4923d
Address review.
daniel-dodd Aug 25, 2022
bd64f3a
Create skeleton notebook.
daniel-dodd Aug 25, 2022
f8157d0
Move trainable dictionaries outside of gradient functions.
daniel-dodd Aug 26, 2022
9d92f5a
Update documentation.
daniel-dodd Aug 26, 2022
f1d3f81
Finish rebase issues.
daniel-dodd Sep 18, 2022
cc1c318
Add copyright, update typing.
daniel-dodd Sep 18, 2022
c873757
Merge branch 'v0.5_update' into natgrads
daniel-dodd Sep 20, 2022
566f654
Complete rebase.
daniel-dodd Sep 20, 2022
24032c0
One Cholesky is better than two.
daniel-dodd Sep 20, 2022
16 changes: 16 additions & 0 deletions docs/refs.bib
@@ -79,3 +79,19 @@ @InProceedings{titsias2009
series = {Proceedings of Machine Learning Research},
publisher = {PMLR},
}

@misc{salimbeni2018,
  doi       = {10.48550/ARXIV.1803.09151},
  url       = {https://arxiv.org/abs/1803.09151},
  author    = {Salimbeni, Hugh and Eleftheriadis, Stefanos and Hensman, James},
  keywords  = {Machine Learning (stat.ML), Machine Learning (cs.LG), FOS: Computer and information sciences},
  title     = {Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models},
  publisher = {arXiv},
  year      = {2018},
}
1 change: 0 additions & 1 deletion examples/classification.ipynb
@@ -410,7 +410,6 @@
"\n",
"params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n",
"mll = posterior.marginal_log_likelihood(D, negative=False)\n",
"\n",
"unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n",
"\n",
"adapt = blackjax.window_adaptation(blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65)\n",
364 changes: 364 additions & 0 deletions examples/natgrads.ipynb
@@ -0,0 +1,364 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "98f89228",
"metadata": {},
"source": [
"# Natural Gradients:"
]
},
{
"cell_type": "markdown",
"id": "02dcd16f",
"metadata": {},
"source": [
"In this notebook, we show how to create natural gradients. Ordinary gradient descent algorithms are an undesirable for variational inference because we are minimising the KL divergence between distributions rather than a set of parameters directly. Natural gradients, on the other hand, accounts for the curvature induced by the KL divergence that has the capacity to considerably improve performance (see e.g., <strong data-cite=\"salimbeni2018\">Salimbeni et al. (2018)</strong> for further details)."
]
},
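{
"cell_type": "markdown",
"id": "b4e1f2a9",
"metadata": {},
"source": [
"As a brief sketch (the notation here is ours, not part of the GPJax API): writing the negative ELBO as $\\mathcal{L}(\\theta)$ for variational parameters $\\theta$, a standard gradient step is $\\theta_{t+1} = \\theta_{t} - \\rho \\nabla_{\\theta} \\mathcal{L}(\\theta_{t})$, whereas the natural gradient step preconditions with the Fisher information $F(\\theta)$ of the variational distribution,\n",
"\n",
"$$\\theta_{t+1} = \\theta_{t} - \\rho \\, F(\\theta_{t})^{-1} \\nabla_{\\theta} \\mathcal{L}(\\theta_{t}),$$\n",
"\n",
"so that steps follow the direction of steepest descent as measured by KL divergence rather than by Euclidean distance."
]
},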
{
"cell_type": "code",
"execution_count": null,
"id": "10376231",
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import matplotlib.pyplot as plt\n",
"from jax import jit, lax\n",
"import optax as ox\n",
"import gpjax as gpx\n",
"\n",
"key = jr.PRNGKey(123)"
]
},
{
"cell_type": "markdown",
"id": "3b851a25",
"metadata": {},
"source": [
"# Dataset:"
]
},
{
"cell_type": "markdown",
"id": "6f7facf2",
"metadata": {},
"source": [
"We simulate a dataset $\\mathcal{D} = (\\boldsymbol{x}, \\boldsymbol{y}) = \\{(x_i, y_i)\\}_{i=1}^{5000}$ with inputs $\\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding binary outputs\n",
"\n",
"$$\\boldsymbol{y} \\sim \\mathcal{N} \\left(\\sin(4 * \\boldsymbol{x}) + \\sin(2 * \\boldsymbol{x}), \\textbf{I} * (0.2)^{2} \\right).$$\n",
"\n",
"We store our data $\\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39d6c8e6",
"metadata": {},
"outputs": [],
"source": [
"n = 5000\n",
"noise = 0.2\n",
"\n",
"x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n",
"f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n",
"signal = f(x)\n",
"y = signal + jr.normal(key, shape=signal.shape) * noise\n",
"\n",
"D = gpx.Dataset(X=x, y=y)\n",
"xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)"
]
},
{
"cell_type": "markdown",
"id": "af57fb31",
"metadata": {},
"source": [
"Intialise inducing points:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf6533b6",
"metadata": {},
"outputs": [],
"source": [
"z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 5))\n",
"ax.plot(x, y, \"o\", alpha=0.3)\n",
"ax.plot(xtest, f(xtest))\n",
"[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "664c204b",
"metadata": {},
"source": [
"# Natural gradients:"
]
},
{
"cell_type": "markdown",
"id": "ce4de494",
"metadata": {},
"source": [
"We begin by defining our model, variational family and variational inference strategy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2284fbb2",
"metadata": {},
"outputs": [],
"source": [
"likelihood = gpx.Gaussian(num_datapoints=n)\n",
"kernel = gpx.RBF()\n",
"prior = gpx.Prior(kernel=kernel)\n",
"p = prior * likelihood\n",
"\n",
"\n",
"natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n",
"natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)\n",
"\n",
"parameter_state = gpx.initialise(natural_svgp)"
]
},
{
"cell_type": "markdown",
"id": "e793c24f",
"metadata": {},
"source": [
"Next, we can conduct natural gradients as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e9884f2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"inference_state = gpx.fit_natgrads(natural_svgp,\n",
" parameter_state=parameter_state,\n",
" train_data = D,\n",
" n_iters = 5000,\n",
" batch_size=100,\n",
" key = jr.PRNGKey(42),\n",
" moment_optim = ox.sgd(1.0),\n",
" hyper_optim = ox.adam(1e-3),\n",
" )\n",
"\n",
"learned_params, training_history = inference_state.unpack()"
]
},
{
"cell_type": "markdown",
"id": "fbcdd41c",
"metadata": {},
"source": [
"Here is the fitted model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cff40778",
"metadata": {},
"outputs": [],
"source": [
"latent_dist = natural_q(learned_params)(xtest)\n",
"predictive_dist = likelihood(latent_dist, learned_params)\n",
"\n",
"meanf = predictive_dist.mean()\n",
"sigma = predictive_dist.stddev()\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 5))\n",
"ax.plot(x, y, \"o\", alpha=0.15, label=\"Training Data\", color=\"tab:gray\")\n",
"ax.plot(xtest, meanf, label=\"Posterior mean\", color=\"tab:blue\")\n",
"ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3)\n",
"[\n",
" ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1)\n",
" for z_i in learned_params[\"variational_family\"][\"inducing_inputs\"]\n",
"]\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "5db1e2e3",
"metadata": {},
"source": [
"# Natural gradients and sparse varational Gaussian process regression:"
]
},
{
"cell_type": "markdown",
"id": "649d29ec",
"metadata": {},
"source": [
"As mentioned in <strong data-cite=\"hensman2013gaussian\">Hensman et al. (2013)</strong>, in the case of a Gaussian likelihood, taking a step of unit length for natural gradients on a full batch of data recovers the same solution as <strong data-cite=\"titsias2009\">Titsias (2009)</strong>. We now illustrate this."
]
},
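{
"cell_type": "markdown",
"id": "e2c91b57",
"metadata": {},
"source": [
"One informal way to see this (a sketch under standard conjugate exponential-family assumptions, not specific to GPJax): the natural gradient update of the variational natural parameters $\\theta$ takes the convex-combination form\n",
"\n",
"$$\\theta_{t+1} = (1 - \\rho)\\,\\theta_{t} + \\rho\\,\\hat{\\theta},$$\n",
"\n",
"where $\\hat{\\theta}$ denotes the natural parameters of the optimal Gaussian variational distribution. Taking $\\rho = 1$ therefore jumps straight to the optimum, which for a Gaussian likelihood coincides with the collapsed solution of <strong data-cite=\"titsias2009\">Titsias (2009)</strong>."
]
},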
{
"cell_type": "code",
"execution_count": null,
"id": "0995d1f2",
"metadata": {},
"outputs": [],
"source": [
"n = 1000\n",
"noise = 0.2\n",
"\n",
"x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n",
"f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n",
"signal = f(x)\n",
"y = signal + jr.normal(key, shape=signal.shape) * noise\n",
"\n",
"D = gpx.Dataset(X=x, y=y)\n",
"\n",
"xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ec554e0",
"metadata": {},
"outputs": [],
"source": [
"z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 5))\n",
"ax.plot(x, y, \"o\", alpha=0.3)\n",
"ax.plot(xtest, f(xtest))\n",
"[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eee8115a",
"metadata": {},
"outputs": [],
"source": [
"likelihood = gpx.Gaussian(num_datapoints=n)\n",
"kernel = gpx.RBF()\n",
"prior = gpx.Prior(kernel=kernel)\n",
"p = prior * likelihood"
]
},
{
"cell_type": "markdown",
"id": "6640c071",
"metadata": {},
"source": [
"We begin with natgrads:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "078e03c4",
"metadata": {},
"outputs": [],
"source": [
"from gpjax.natural_gradients import natural_gradients\n",
"\n",
"q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n",
"svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n",
"params, trainables, bijectors = gpx.initialise(svgp).unpack()\n",
"\n",
"params = gpx.unconstrain(params, bijectors)\n",
"\n",
"nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, bijectors, trainables)\n",
"\n",
"moment_optim = ox.sgd(1.0)\n",
"\n",
"moment_state = moment_optim.init(params)\n",
"\n",
"# Natural gradients update:\n",
"loss_val, loss_gradient = nat_grads_fn(params, D)\n",
"print(loss_val)\n",
"\n",
"updates, moment_state = moment_optim.update(loss_gradient, moment_state, params)\n",
"params = ox.apply_updates(params, updates)\n",
"\n",
"loss_val, _ = nat_grads_fn(params, D)\n",
"\n",
"print(loss_val)"
]
},
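{
"cell_type": "markdown",
"id": "a7d30c48",
"metadata": {},
"source": [
"For completeness, here is a sketch of the hyperparameter step that `fit_natgrads` interleaves with the moment update above. We assume `hyper_grads_fn` shares the `(params, batch) -> (loss, grads)` signature of `nat_grads_fn`; the optimiser choice is illustrative, and we deliberately do not overwrite `params` so that the SGPR comparison below is unaffected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1f9e6d2",
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameter update (a sketch; params is intentionally left unchanged):\n",
"hyper_optim = ox.adam(1e-3)\n",
"hyper_state = hyper_optim.init(params)\n",
"\n",
"hyper_loss, hyper_gradient = hyper_grads_fn(params, D)\n",
"hyper_updates, hyper_state = hyper_optim.update(hyper_gradient, hyper_state, params)\n",
"new_params = ox.apply_updates(params, hyper_updates)\n",
"\n",
"print(hyper_loss)"
]
},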
{
"cell_type": "markdown",
"id": "c7c16824",
"metadata": {},
"source": [
"Let us now run it for SGPR:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6482af13",
"metadata": {},
"outputs": [],
"source": [
"q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)\n",
"sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)\n",
"\n",
"params, _, _ = gpx.initialise(svgp).unpack()\n",
"\n",
"loss_fn = sgpr.elbo(D, negative=True)\n",
"\n",
"loss_val = loss_fn(params)\n",
"\n",
"print(loss_val)"
]
},
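{
"cell_type": "markdown",
"id": "f3a8b1c6",
"metadata": {},
"source": [
"As a quick numerical sanity check (both quantities are negative ELBOs), we can compare the two objective values directly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e4a733",
"metadata": {},
"outputs": [],
"source": [
"# Gap between the natural-gradient objective after one unit step and the\n",
"# collapsed SGPR bound (expected to be small):\n",
"nat_loss, _ = nat_grads_fn(params, D)\n",
"\n",
"print(jnp.abs(nat_loss - sgpr_loss))"
]
},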
{
"cell_type": "markdown",
"id": "bdae1c03",
"metadata": {},
"source": [
"The discrepancy is due to the quadrature approximation."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.0 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"vscode": {
"interpreter": {
"hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}