From c7b3b1b4c7d4bf51f596cd99789dfc6bd3d44883 Mon Sep 17 00:00:00 2001
From: Thomas Pinder
Date: Thu, 18 May 2023 21:17:33 +0100
Subject: [PATCH] Resolve comments

---
 docs/_static/bijector_figure.svg  | 813 ++++++++++++++++++++++++++++++
 docs/_static/step_size_figure.svg | 559 ++++++++++++++++++++
 docs/refs.bib                     |  12 +
 docs/scripts/sharp_bits_figure.py |  93 ++++
 docs/sharp_bits.md                |  62 ++-
 5 files changed, 1515 insertions(+), 24 deletions(-)
 create mode 100644 docs/_static/bijector_figure.svg
 create mode 100644 docs/_static/step_size_figure.svg
 create mode 100644 docs/scripts/sharp_bits_figure.py

diff --git a/docs/_static/bijector_figure.svg b/docs/_static/bijector_figure.svg
new file mode 100644
index 00000000..5739003e
--- /dev/null
+++ b/docs/_static/bijector_figure.svg
@@ -0,0 +1,813 @@
+[813 lines of SVG markup omitted: Matplotlib v3.7.1 (https://matplotlib.org/) bijector figure, image/svg+xml, created 2023-05-18T20:58:11.814177]

diff --git a/docs/_static/step_size_figure.svg b/docs/_static/step_size_figure.svg
new file mode 100644
index 00000000..f9485927
--- /dev/null
+++ b/docs/_static/step_size_figure.svg
@@ -0,0 +1,559 @@
+[559 lines of SVG markup omitted: Matplotlib v3.7.1 (https://matplotlib.org/) step-size figure, image/svg+xml, created 2023-05-18T20:20:35.450666]

diff --git a/docs/refs.bib b/docs/refs.bib
index 6c62450c..1889d188 100644
--- a/docs/refs.bib
+++ b/docs/refs.bib
@@ -112,3 +112,15 @@ @inproceedings{wilson2020efficient
   numpages = {11},
   series = {ICML'20}
 }
+
+@book{higham2022accuracy,
+  author = {Higham, Nicholas J.},
+  title = {Accuracy and Stability of Numerical Algorithms},
+  publisher = {Society for Industrial and Applied Mathematics},
+  year = {2002},
+  doi = {10.1137/1.9780898718027},
+  address = {},
+  edition = {Second},
+  url = {https://epubs.siam.org/doi/abs/10.1137/1.9780898718027},
+  eprint = {https://epubs.siam.org/doi/pdf/10.1137/1.9780898718027}
+}
diff --git a/docs/scripts/sharp_bits_figure.py b/docs/scripts/sharp_bits_figure.py
new file mode 100644
index 00000000..dce70a8f
--- /dev/null
+++ b/docs/scripts/sharp_bits_figure.py
@@ -0,0 +1,93 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: gpjax_baselines
+#     language: python
+#     name: python3
+# ---
+
+# %%
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import matplotlib.patches as patches
+
+plt.style.use("../examples/gpjax.mplstyle") +cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + +# %% +fig, ax = plt.subplots() +ax.axhline(y = 0.25, color=cols[0], linewidth=1.5) + +xs = [0.02, 0.06, 0.1, 0.17] +ys = np.ones_like(xs) * 0.25 + +ax.scatter(xs, ys, color=cols[1], marker="o", s=100, zorder=2) + +for idx, x in enumerate(xs): + ax.annotate(text = f'$\ell_{{t-{idx+1}}}$', xy=(x, 0.25), xytext=(x+0.01, 0.275), ha='center', va='bottom') + + +style = "Simple, tail_width=0.5, head_width=4, head_length=8" +kw = dict(arrowstyle=style, color="k") + +for i in range(len(xs)-1): + a = patches.FancyArrowPatch((xs[i+1], 0.25), (xs[i], 0.25), connectionstyle="arc3,rad=-.5", **kw) + ax.add_patch(a) + + +ax.scatter(-0.03, 0.25, color=cols[1], marker="x", s=100, linewidth=5, zorder=2) + +a = patches.FancyArrowPatch((xs[0], 0.25), (-0.03, 0.25), connectionstyle="arc3,rad=-.5", **kw) +ax.add_patch(a) + +ax.axvline(x = 0, color='black', linewidth=0.5, linestyle='-.') +ax.get_yaxis().set_visible(False) +ax.spines["left"].set_visible(False) +ax.set_ylim(0., 0.5) +ax.set_xlim(-0.07, 0.25) +plt.savefig('../_static/step_size_figure.svg', bbox_inches='tight') + +# %% +import tensorflow_probability.substrates.jax.bijectors as tfb +import jax.numpy as jnp + +bij = tfb.Exp() + +x = np.linspace(0.05, 3., 6) +y = np.asarray(bij.inverse(x)) +lval = 0.5 +rval = 0.52 + +fig, ax = plt.subplots() +ax.scatter(x, np.ones_like(x)*lval, s=100, label='Constrained value') +ax.scatter(y, np.ones_like(y)*rval, marker='o', s=100, label='Unconstrained value') + +style = "Simple, tail_width=0.25, head_width=2, head_length=8" +for i in range(len(x)): + if i%2 != 0: + a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=-.15", **kw) + # a = patches.Arrow(lval, x[i], rval-lval, y[i]-x[i], width=0.05, color='k') + else: + a = patches.FancyArrowPatch((x[i], lval), (y[i], rval), connectionstyle="arc3,rad=.005", **kw) + ax.add_patch(a) + +ax.get_yaxis().set_visible(False) +ax.spines["left"].set_visible(False) +ax.legend(loc='best') +# ax.set_ylim(0.1, 0.32) +plt.savefig('../_static/bijector_figure.svg', bbox_inches='tight') + +# %% +np.log(0.05) + +# %% +x diff --git a/docs/sharp_bits.md b/docs/sharp_bits.md index c65e89a7..15153e6e 100644 --- a/docs/sharp_bits.md +++ b/docs/sharp_bits.md @@ -2,9 +2,9 @@ ## Pseudo-randomness -Libraries like Numpy and Scipy use *stateful* pseudorandom number generators (PRNGs). +Libraries like NumPy and Scipy use *stateful* pseudorandom number generators (PRNGs). However, the PRNG in JAX is stateless. This means that for a given function, the -return always return the same result unless the seed is changed. This is a good thing, +return always returns the same result unless the seed is changed. This is a good thing, but it means that we need to be careful when using JAX's PRNGs. To examine what it means for a PRNG to be stateful, consider the following example: @@ -14,8 +14,8 @@ import numpy as np import jax.random as jr key = jr.PRNGKey(123) -# Numpy -print('Numpy:') +# NumPy +print('NumPy:') print(np.random.random()) print(np.random.random()) @@ -28,7 +28,7 @@ key, subkey = jr.split(key) print(jr.uniform(subkey)) ``` ```console -Numpy: +NumPy: 0.5194454541172852 0.9815886617924413 @@ -39,9 +39,9 @@ JAX: Splitting key 0.23886406 ``` -We can see that, in libraries like Numpy, the PRNG key's state is incremented whenever +We can see that, in libraries like NumPy, the PRNG key's state is incremented whenever a pseudorandom call is made. 
 
@@ -53,12 +53,16 @@ Parameters such as the kernel's lengthscale or variance have their support defin
 a constrained subset of the real-line. During gradient-based optimisation, as we
 approach the set's boundary, it becomes possible that we could step outside of the
 set's support and introduce a numerical and mathematical error into our model. For
-example, consider the variance parameter $`\sigma^2`$, which we know must be strictly
-positive. If at $`t^{\text{th}}`$ iterate, our current estimate of $`\sigma^2`$ was
-0.03 and our derivative informed us that $`\sigma^2`$ should decrease, then if our
-learning rate is greater is than 0.03, we would end up with a negative variance term.
+example, consider the lengthscale parameter $`\ell`$, which we know must be strictly
+positive. If at the $`t^{\text{th}}`$ iterate, our current estimate of $`\ell`$ was
+0.02 and our derivative informed us that $`\ell`$ should decrease, then if our
+learning rate is greater than 0.02, we would end up with a negative lengthscale.
+We visualise this issue below, where the red cross denotes the invalid lengthscale value
+that would be obtained were we to apply the gradient update directly to the constrained value.
 
-A simple, but impractical solution, would be to use a tiny learning rate which would
+![](_static/step_size_figure.svg)
+
+A simple but impractical solution would be to use a tiny learning rate which would
 reduce the possibility of stepping outside of the parameter's support. However, this
 would be incredibly costly and does not eradicate the problem. An alternative solution
 is to apply a functional mapping to the parameter that projects it from a constrained
 subspace of the real-line onto the entire real-line. Here, gradient updates are
 applied in the unconstrained parameter space before transforming the value back to the
 original support of the parameters. Such a transformation is known as a bijection.
 
+![](_static/bijector_figure.svg)
+
+To help understand this, we show the effect of using a log-exp bijector in the above
+figure. We have six points on the positive real line that range from 0.1 to 3, depicted
+by blue crosses. We then apply the bijector by log-transforming the constrained values.
+This gives us the points' unconstrained values, which we depict by red circles. It is
+to these values that we apply gradient updates. When we wish to recover the constrained
+values, we apply the inverse of the bijector, which is the exponential function in this
+case. This gives us back the blue crosses.
 
 In GPJax, we supply bijective functions using
 [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors).
 In our [PyTrees doc](examples/pytrees.md) document, we detail how the user can define
 their own bijectors and attach them to the parameter(s) of their model.
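+
+To make the pattern concrete, below is a minimal sketch of the transform, update,
+inverse-transform cycle using the `Exp` bijector from the figure; the gradient step of
+`1.0` is purely illustrative:
+
+```python
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.bijectors as tfb
+
+bijector = tfb.Exp()  # forward: real line -> positive reals, inverse: log
+
+lengthscale = jnp.array(0.02)                  # constrained value, must stay positive
+unconstrained = bijector.inverse(lengthscale)  # log(0.02), lives on the whole real line
+
+# A gradient step of any size is safe in the unconstrained space...
+unconstrained = unconstrained - 1.0
+
+# ...because mapping back through the bijector always lands inside the support.
+lengthscale = bijector.forward(unconstrained)  # strictly positive again
+```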
 ## Positive-definiteness
 
-> "Symmetric positive definiteness is one of the highest accolades to which a matrix can aspire" - Nicholas Highman, Accuracy and stability of numerical algorithms
+> "Symmetric positive definiteness is one of the highest accolades to which a matrix can aspire" - Nicholas Higham, Accuracy and Stability of Numerical Algorithms [@higham2022accuracy]
 
 ### Why is positive-definiteness important?
 
-The covariance matrix of a kernel is a symmetric positive definite matrix. As such, we
+The Gram matrix of a kernel, a concept that we explore more in our
+[kernels notebook](examples/kernels.py) and our [PyTree notebook](examples/pytrees.md), is a
+symmetric positive definite matrix. As such, we
 have a range of tools at our disposal to make subsequent operations on the covariance
 matrix faster. One of these tools is the Cholesky factorisation that uniquely
 decomposes any symmetric positive-definite matrix $`\mathbf{\Sigma}`$ by
 
 ```math
 \begin{align}
-    \mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}
+    \mathbf{\Sigma} = \mathbf{L}\mathbf{L}^{\top}\,,
 \end{align}
 ```
-
 where $`\mathbf{L}`$ is a lower triangular matrix. We make use of this result in GPJax
 when solving linear systems of equations of the form
 $`\mathbf{A}\boldsymbol{x} = \boldsymbol{b}`$. Whilst seemingly abstract at first, such
 problems are frequently encountered when constructing Gaussian process models. One such
 example arises in the regression setting when learning Gaussian
-process Kernel hyperparameters. Here we have labels
-$`\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \mathbf{\Sigma})`$ with $`f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})`$ arising from zero-mean
-Gaussian process prior and gram matrix $`\mathbf{K}_{\boldsymbol{xx}}`$ at the inputs
+process kernel hyperparameters. Here we have labels
+$`\boldsymbol{y} \sim \mathcal{N}(f(\boldsymbol{x}), \sigma^2\mathbf{I})`$ with $`f(\boldsymbol{x}) \sim \mathcal{N}(\boldsymbol{0}, \mathbf{K}_{\boldsymbol{xx}})`$ arising from a zero-mean
+Gaussian process prior and Gram matrix $`\mathbf{K}_{\boldsymbol{xx}}`$ evaluated at the inputs
 $`\boldsymbol{x}`$. Here the marginal log-likelihood takes the following form
 
 ```math
 \begin{align}
-    \log p(\boldsymbol{y}) = 0.5\left(-\boldsymbol{y}^{\top}\left(\mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I} \right)^{-1}\boldsymbol{y} -\log\lvert \mathbf{K}_{\boldsymbol{xx}} + \mathbf{\Sigma}\rvert -n\log(2\pi)\right) ,
+    \log p(\boldsymbol{y}) = 0.5\left(-\boldsymbol{y}^{\top}\left(\mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I} \right)^{-1}\boldsymbol{y} -\log\lvert \mathbf{K}_{\boldsymbol{xx}} + \sigma^2\mathbf{I}\rvert -n\log(2\pi)\right) ,
 \end{align}
 ```
 
-and the goal of inference is to maximise kernel hyperparameters (contained in the gram
+and the goal of inference is to maximise this quantity with respect to the kernel hyperparameters (contained in the Gram
 matrix $`\mathbf{K}_{\boldsymbol{xx}}`$) and likelihood hyperparameters (contained in the
-noise covariance $`\mathbf{\Sigma}`$). Computing the marginal log-likelihood (and its
+noise covariance $`\sigma^2\mathbf{I}`$). Computing the marginal log-likelihood (and its
 gradients) draws our attention to the term
 
 ```math
@@ -131,7 +145,7 @@ While the computational acceleration provided by using Cholesky factors instead
 matrices is hopefully now apparent, an awkward numerical instability _gotcha_ can arise
 due to floating-point rounding errors. When we evaluate a covariance function on a
 set of points that are very _close_ to one another, eigenvalues of the corresponding
-covariance matrix can get very small. So small that after numerical rounding, the
+Gram matrix can get very small. So small that after numerical rounding, the
 smallest eigenvalues can become negative-valued. While not truly less than zero, our
 computer thinks they are, which becomes a problem when we want to compute a Cholesky
 factor since this requires that the input matrix is positive-definite. If there are
 negative eigenvalues, then the Cholesky factorisation will fail. The resolution is to
 add a small amount of noise, or _jitter_, to the diagonal of the Gram matrix. GPJax
 adds a small amount of jitter by default but,
 for some problems, this amount may need to be increased.
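+
+As a rough sketch of the idea (the kernel, inputs, and jitter value below are
+hypothetical rather than GPJax defaults), we form the Gram matrix, add jitter to its
+diagonal, and solve against the Cholesky factor instead of forming an explicit inverse:
+
+```python
+import jax.numpy as jnp
+from jax.scipy.linalg import cho_solve
+
+# Three inputs, two of which are nearly identical, give a near-singular Gram matrix.
+x = jnp.array([[0.5], [0.50001], [1.0]])
+K = jnp.exp(-0.5 * (x - x.T) ** 2)   # RBF kernel with unit hyperparameters
+
+jitter = 1e-6                        # problem dependent; increase if factorisation fails
+K_stable = K + jitter * jnp.eye(K.shape[0])
+
+L = jnp.linalg.cholesky(K_stable)    # lower-triangular Cholesky factor
+y = jnp.array([0.1, 0.2, 0.3])
+alpha = cho_solve((L, True), y)      # solves (K + jitter * I) alpha = y
+```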
 
 ## Slow-to-evaluate
 
 Famously, a regular Gaussian process model (as detailed in
 [our regression notebook](examples/regression.py)) will scale cubically in the number
 of data points.
-Consequently, if you try to fit your Gaussian process model to data set containing more
+Consequently, if you try to fit your Gaussian process model to a data set containing more
 than several thousand data points, then you will likely incur a significant
 computational overhead. In such cases, we recommend using Sparse Gaussian processes to
 alleviate this issue.