Update parameters transform and backend doc
thomaspinder committed Aug 16, 2024
1 parent 7e628f0 commit 172e61e
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions examples/backend.py
@@ -81,7 +81,9 @@

# %% [markdown]
# However, suppose you wish your mean function's constant parameter to be strictly
-# positive. This is easy to achieve by using the correct Parameter type which, in this case, will be the `PositiveReal`. However, any Parameter that subclasses from `Parameter` will be transformed by GPJax.
+# positive. This is easy to achieve by using the correct Parameter type which, in this
+# case, will be the `PositiveReal`. However, any Parameter that subclasses from
+# `Parameter` will be transformed by GPJax.

# %%
from gpjax.parameters import PositiveReal
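
As an illustrative sketch (not part of the commit itself), the snippet below shows how the mean function's constant might be declared with a `PositiveReal` parameter; the exact `Constant` and `PositiveReal` signatures are assumptions here rather than something read from the diff.

import jax.numpy as jnp
import gpjax as gpx
from gpjax.parameters import PositiveReal

# Declare the constant as a strictly positive parameter; GPJax later maps this
# parameter type to its default positivity bijector (assumed constructor).
constant = PositiveReal(jnp.array(1.0))

# Build the mean function with the constrained parameter (assumed signature).
meanf = gpx.mean_functions.Constant(constant)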
@@ -134,7 +136,8 @@

# %% [markdown]
# We see here that the Softplus bijector is specified as the default for strictly
-# positive parameters. To apply this, we must first realise the _state_ of our model. This is achieved using the `split` function provided by `nnx`.
+# positive parameters. To apply this, we must first realise the _state_ of our model.
+# This is achieved using the `split` function provided by `nnx`.

# %%
_, _params = nnx.split(meanf, Parameter)
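
For context, a minimal sketch of what splitting the state and applying the default bijection could look like; the `transform` helper and `DEFAULT_BIJECTION` mapping are assumed names inferred from the surrounding text, not taken from the diff.

from flax import nnx
from gpjax.parameters import Parameter, transform, DEFAULT_BIJECTION

# Split the model into its static graph definition and its Parameter state
# (remaining state captured separately).
graphdef, params, *others = nnx.split(meanf, Parameter, ...)

# Map constrained parameters to unconstrained space, e.g. an inverse-Softplus
# for any strictly positive leaves, so gradient updates remain valid.
unconstrained_params = transform(params, DEFAULT_BIJECTION, inverse=True)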
@@ -341,7 +344,11 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat:
param_grads = grad(loss_fn)(params, D)

# %% [markdown]
-# In practice, you would wish to perform multiple iterations of gradient descent to learn the optimal parameter values. However, for the purposes of illustration, we use another `tree_map` below to update the parameters' state using their previously computed gradients. As you can see, the real beauty of having access to the model's state is that we have full control over the operations that we perform on the state.
+# In practice, you would wish to perform multiple iterations of gradient descent to
+# learn the optimal parameter values. However, for the purposes of illustration, we use
+# another `tree_map` below to update the parameters' state using their previously
+# computed gradients. As you can see, the real beauty of having access to the model's
+# state is that we have full control over the operations that we perform on the state.

# %%
LEARNING_RATE = 0.01
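
For completeness, here is a sketch of what several such gradient-descent iterations might look like, reusing the `loss_fn`, `params`, and `D` names from the diff; this is illustrative only and not part of the commit.

import jax
from jax import grad

LEARNING_RATE = 0.01
for _ in range(100):
    # Differentiate the loss with respect to the parameter state.
    param_grads = grad(loss_fn)(params, D)
    # Apply a plain gradient-descent step leaf-by-leaf over the state.
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, param_grads
    )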
@@ -350,7 +357,9 @@ def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat:
)

# %% [markdown]
-# Now we will plot the updated mean function alongside its initial form. To achieve this, we first merge the state back into the model using `merge`, and we then simply invoke the model as normal.
+# Now we will plot the updated mean function alongside its initial form. To achieve
+# this, we first merge the state back into the model using `merge`, and we then simply
+# invoke the model as normal.

# %%
optimised_posterior = nnx.merge(graphdef, optimised_params, *others)
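
Finally, a sketch of invoking the merged model "as normal", assuming the usual GPJax `predict` API; the test inputs below are made up for illustration.

import jax.numpy as jnp

# Evaluate the optimised posterior at some test locations (illustrative inputs).
xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
latent_dist = optimised_posterior.predict(xtest, train_data=D)
predictive_mean = latent_dist.mean()  # can be plotted against the initial model's mean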