Fix LaTeX rendering + minor typos. #257

Merged: 1 commit, May 12, 2023
2 changes: 1 addition & 1 deletion docs/examples/classification.py
@@ -59,7 +59,7 @@
# ## Dataset
#
# With the necessary modules imported, we simulate a dataset
# $\mathcal{D} = (, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs
# $\mathcal{D} = (\boldsymbol{x}, \boldsymbol{y}) = \{(x_i, y_i)\}_{i=1}^{100}$ with inputs
# $\boldsymbol{x}$ sampled uniformly on $(-1., 1)$ and corresponding binary outputs
#
# $$\boldsymbol{y} = 0.5 * \text{sign}(\cos(2 * \boldsymbol{x} + \boldsymbol{\epsilon})) + 0.5, \quad \boldsymbol{\epsilon} \sim \mathcal{N} \left(\textbf{0}, \textbf{I} * (0.05)^{2} \right).$$
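
As an aside, a minimal sketch of how such a simulation could be written in JAX (the seed, shapes, and key handling are illustrative, not from the original file):

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(123)  # illustrative seed
key_x, key_eps = jr.split(key)

# 100 inputs sampled uniformly on (-1, 1).
x = jr.uniform(key_x, shape=(100, 1), minval=-1.0, maxval=1.0)

# Binary outputs y in {0, 1} via a noisy sign of cos(2x).
eps = 0.05 * jr.normal(key_eps, shape=(100, 1))
y = 0.5 * jnp.sign(jnp.cos(2.0 * x + eps)) + 0.5
```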
34 changes: 19 additions & 15 deletions docs/examples/pytrees.md
@@ -37,11 +37,17 @@ The kernel in a Gaussian process model is a mathematical function that
defines the covariance structure between data points, allowing us to model
complex relationships and make predictions based on the observed data. The
radial basis function (RBF, or _squared exponential_) kernel is a popular
choice. For any pair of vectors $x, y \in \mathbb{R}^d$, its form is
given by $$ k(x, y) = \sigma^2\exp\left(-\frac{\lVert
x-y\rVert_{2}^2}{2\ell^2} \right) $$ where $\sigma^2\in\mathbb{R}_{>0}$ is a
variance parameter and $\ell^2\in\mathbb{R}_{>0}$ a lengthscale parameter.
Terming the evaluation of $k(x, y)$ the _covariance_, we can represent
choice. For any pair of vectors $`x, y \in \mathbb{R}^d`$, its form is
given by

```math
k(x, y) = \sigma^2\exp\left(-\frac{\lVert
x-y\rVert_{2}^2}{2\ell^2} \right)
```

where $`\sigma^2\in\mathbb{R}_{>0}`$ is a
variance parameter and $`\ell^2\in\mathbb{R}_{>0}`$ a lengthscale parameter.
Terming the evaluation of $`k(x, y)`$ the _covariance_, we can represent
this object as a Python `dataclass` as follows:
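
The class body is collapsed in this diff view; a minimal sketch of what it plausibly contains, with `covariance` following the formula above (scalar inputs assumed, implementation not from the original):

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class RBF:
    lengthscale: float = 1.0
    variance: float = 1.0

    def covariance(self, x, y):
        # k(x, y) = sigma^2 * exp(-||x - y||^2 / (2 * ell^2))
        return self.variance * jnp.exp(-((x - y) ** 2) / (2.0 * self.lengthscale**2))
```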


@@ -80,8 +86,8 @@ class RBF:
To establish some terminology, within the above RBF `dataclass`, we refer to
the lengthscale and variance as _fields_. Further, the `RBF.covariance` is a
_method_. So far so good. However, if we wanted to take the gradient of
the kernel with respect to its parameters $\nabla_{\ell, \sigma^2} k(1.0, 2.0;
\ell, \sigma^2)$ at inputs $x=1.0$ and $y=2.0$, then we encounter a problem:
the kernel with respect to its parameters $`\nabla_{\ell, \sigma^2} k(1.0, 2.0;
\ell, \sigma^2)`$ at inputs $`x=1.0`$ and $`y=2.0`$, then we encounter a problem:

```python
kernel = RBF()
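# The diff collapses the rest of this snippet; presumably the failing
# gradient call looks something like this (assumed reconstruction):
import jax

jax.grad(lambda k: k.covariance(1.0, 2.0))(kernel)
# ...raises a TypeError, because a plain dataclass is not registered
# with JAX as a PyTree, so JAX cannot flatten it.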
```

@@ -203,8 +209,7 @@ print(gradient)
This computes the gradient of the `sum_squares` function with respect to the
input PyTree, and returns a new PyTree with the same shape and structure.
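
The snippet that produced this output is collapsed above; a plausible reconstruction (the PyTree contents are illustrative):

```python
import jax
import jax.numpy as jnp


def sum_squares(tree):
    # Sum of squared entries across every leaf of the PyTree.
    return sum(jnp.sum(leaf**2) for leaf in jax.tree_util.tree_leaves(tree))


params = {"a": jnp.array([1.0, 2.0]), "b": (jnp.array(3.0),)}
gradient = jax.grad(sum_squares)(params)
print(gradient)  # a PyTree mirroring `params`, with each leaf equal to 2 * x
```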

<!-- #region JAX PyTrees are also designed to be highly extensible, where -->
custom types can be readily registered through a global registry with the
JAX PyTrees are also designed to be highly extensible, where custom types can be readily registered through a global registry, with
their values traversed recursively (i.e., as a tree!). This means we can
define our own custom data structures and use them as PyTrees. This is the
functionality that we exploit, whereby we construct all Gaussian process
@@ -317,8 +322,8 @@ RBF(lengthscale=Array(3.14, dtype=float32), variance=Array(1., dtype=float32))
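
The collapsed code registers such custom types; a minimal sketch of JAX's registration mechanism, using an illustrative `Pair` container that is not from the original:

```python
import jax
from jax import tree_util


class Pair:
    def __init__(self, first, second):
        self.first = first
        self.second = second


def _flatten_pair(p):
    # Children that JAX should traverse, plus static auxiliary data.
    return (p.first, p.second), None


def _unflatten_pair(aux, children):
    return Pair(*children)


tree_util.register_pytree_node(Pair, _flatten_pair, _unflatten_pair)

# Pair now behaves like any other PyTree node.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, Pair(1.0, 2.0))
```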
## Trainability 🚂

Recall the example earlier, where we wanted to take the gradient of the kernel
with respect to its parameters $\nabla_{\ell, \sigma^2} k(1.0, 2.0; \ell,
with respect to its parameters $`\nabla_{\ell, \sigma^2} k(1.0, 2.0; \ell,
\sigma^2)$ at inputs $x=1.0$ and $y=2.0$. We can now confirm we can do this
\sigma^2)`$ at inputs $`x=1.0`$ and $`y=2.0`$. We can now confirm we can do this
with the new `Module`.

```python
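# Hedged sketch of the elided snippet: assuming `RBF` is now defined via the
# PyTree-aware `Module`, the gradient that failed earlier goes through.
import jax

kernel = RBF()
gradient = jax.grad(lambda k: k.covariance(1.0, 2.0))(kernel)
print(gradient)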
```

@@ -334,7 +339,7 @@ During gradient learning of models, it can sometimes be useful to fix certain
parameters during the optimisation routine. For this, JAX provides a
`stop_gradient` operator to prevent the flow of gradients during forward or
reverse-mode automatic differentiation, as illustrated below for a function
$f(x) = x^2$.
$`f(x) = x^2`$.

```python
from jax import lax
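# Hedged completion of this truncated snippet: lax.stop_gradient blocks the
# flow of gradients, so differentiating f(x) = x**2 through it yields zero.
import jax


def f(x):
    x = lax.stop_gradient(x)
    return x**2


print(jax.grad(f)(1.0))  # 0.0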
```

@@ -579,8 +584,7 @@ transformed PyTree, as demonstrated in the examples that follow.

### Filter example:

<!-- #region A `meta_map` works similarly to `jax.tree_utils.tree_map`. -->
However, it differs in that it allows us to define a function that operates on
A `meta_map` works similarly to `jax.tree_util.tree_map`. However, it differs in that it allows us to define a function that operates on
the tuple (metadata, leaf value). For example, we could use a function to
filter based on a `name` attribute.
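
The example code is collapsed in this view; a sketch of what it plausibly looks like, assuming `meta_map` hands the function a `(metadata, leaf)` tuple as described, that the metadata is a mapping with a `name` entry, and with an illustrative replacement value:

```python
def filter_lengthscale(meta_leaf):
    meta, leaf = meta_leaf
    # Replace any leaf whose metadata names it "lengthscale"; pass others through.
    if meta.get("name", None) == "lengthscale":
        return 3.14
    return leaf


meta_map(filter_lengthscale, kernel)
```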

@@ -610,7 +614,7 @@ To apply a constraint, we filter on the attribute "bijector", and apply a
forward transformation to the PyTree leaf:

```python
# This is how constrain works.
# This is how constrain works! ⛏
def _apply_constrain(meta_leaf):
    meta, leaf = meta_leaf
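    # The diff truncates here; presumably the leaf is pushed through the
    # bijector stored in its metadata, along these (assumed) lines:
    return meta["bijector"].forward(leaf)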

```