adapt examples markdown cells
frazane committed Aug 24, 2023
1 parent 4d1e4ae commit 76af7d8
Showing 6 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/examples/classification.py
@@ -114,7 +114,7 @@

# %% [markdown]
# We can obtain a MAP estimate by optimising the log-posterior density with
# Optax's optimisers.
# `jaxopt` solvers.

# %%
negative_lpd = jax.jit(gpx.LogPosteriorDensity(negative=True))
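For readers unfamiliar with the pattern this hunk switches to, the sketch below shows how a `jaxopt` solver wraps an `optax` optimiser to minimise an objective. It uses a toy quadratic in place of the notebook's `negative_lpd`, and the learning rate and iteration count are illustrative assumptions rather than values taken from the example.

```python
import jax.numpy as jnp
import jaxopt
import optax

# Toy stand-in for the notebook's negative log-posterior density.
def objective(params):
    return jnp.sum((params - 2.0) ** 2)

# Wrap optax's Adam in a jaxopt solver and run it to (approximate) convergence.
solver = jaxopt.OptaxSolver(fun=objective, opt=optax.adam(learning_rate=0.01), maxiter=500)
result = solver.run(jnp.zeros(3))
print(result.params)  # each entry approaches 2.0
```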
4 changes: 2 additions & 2 deletions docs/examples/deep_kernels.py
@@ -183,8 +183,8 @@ def __call__(self, x):
# hyperparameter set.
#
# With the inclusion of a neural network, we take this opportunity to highlight the
# additional benefits gleaned from using
# [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we
# additional benefits gleaned from using `jaxopt`'s
# [Optax](https://optax.readthedocs.io/en/latest/) solver for optimisation. In particular, we
# showcase the ability to use a learning rate scheduler that decays the optimiser's
# learning rate throughout the inference. We decrease the learning rate according to a
# half-cosine curve over 700 iterations, providing us with large step sizes early in
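To make the scheduler described above concrete, here is a hedged sketch of a half-cosine decay over 700 iterations fed into an `optax` optimiser that is then wrapped by `jaxopt`'s `OptaxSolver`; the 0.01 initial learning rate and the stand-in objective are assumptions, not values from the notebook.

```python
import jax.numpy as jnp
import jaxopt
import optax

# Learning rate follows a half-cosine curve from 0.01 down to zero over 700 steps.
schedule = optax.cosine_decay_schedule(init_value=0.01, decay_steps=700)

# The schedule is passed to the optax optimiser, which jaxopt wraps exactly as
# it would a constant learning rate.
solver = jaxopt.OptaxSolver(
    fun=lambda params: jnp.sum(params**2),  # stand-in objective
    opt=optax.adam(learning_rate=schedule),
    maxiter=700,
)
result = solver.run(jnp.ones(2))
```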
2 changes: 1 addition & 1 deletion docs/examples/graph_kernels.py
@@ -133,7 +133,7 @@
# For this reason, we simply perform gradient descent on the GP's marginal
# log-likelihood term as in the
# [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
# We do this using the Adam optimiser provided in `optax`.
# We do this using the `OptaxSolver` provided by `jaxopt`, instantiated with the Adam optimiser.

# %%
likelihood = gpx.Gaussian(num_datapoints=D.n)
4 changes: 2 additions & 2 deletions docs/examples/regression.py
@@ -211,8 +211,8 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# optimiser.
# We can now train our model using a `jaxopt` solver. In this case we opt for the `OptaxSolver`,
# which wraps an `optax` optimiser.

# %%
opt_posterior, history = gpx.fit(
2 changes: 1 addition & 1 deletion docs/examples/uncollapsed_vi.py
@@ -229,7 +229,7 @@
# see Sections 3.1 and 4.1 of the excellent review paper
# <strong data-cite="leibfried2020tutorial"></strong>.
#
# Since Optax's optimisers work to minimise functions, to maximise the ELBO we return
# Since `jaxopt`'s solvers work to minimise functions, to maximise the ELBO we return
# its negative.

# %%
3 changes: 2 additions & 1 deletion docs/examples/yacht.py
@@ -183,7 +183,8 @@
# ### Model Optimisation
#
# With a model now defined, we can proceed to optimise the hyperparameters of our
# model using Optax.
# model using one of `jaxopt`'s solvers. In this case we use a solver that wraps an
# `optax` optimiser.

# %%
training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)
