Skip to content

Commit

Permalink
fix graph kernel notebook plot
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Oct 27, 2023
1 parent 88a9667 commit 392b3da
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@
full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
num_datapoints=D.n
)
negative_mll = jit(gpx.ConjugateMLL(negative=True))
negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
# %timeit negative_mll(full_rank_model, D).block_until_ready()

# %%
negative_elbo = jit(gpx.CollapsedELBO(negative=True))
negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
# %timeit negative_elbo(q, D).block_until_ready()

# %% [markdown]
Expand Down
8 changes: 5 additions & 3 deletions docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# %% [markdown]
# # Graph Kernels
#
Expand Down Expand Up @@ -119,7 +120,8 @@
cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
sm.set_array([])
cbar = plt.colorbar(sm)
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

# %% [markdown]
#
Expand Down Expand Up @@ -201,8 +203,8 @@
sm = plt.cm.ScalarMappable(
cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
sm.set_array([])
cbar = plt.colorbar(sm)
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

# %% [markdown]
#
Expand Down

0 comments on commit 392b3da

Please sign in to comment.