diff --git a/docs/examples/collapsed_vi.py b/docs/examples/collapsed_vi.py index 8dd442e3..80e0436a 100644 --- a/docs/examples/collapsed_vi.py +++ b/docs/examples/collapsed_vi.py @@ -244,11 +244,11 @@ full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian( num_datapoints=D.n ) -negative_mll = jit(gpx.ConjugateMLL(negative=True)) +negative_mll = jit(gpx.ConjugateMLL(negative=True).step) # %timeit negative_mll(full_rank_model, D).block_until_ready() # %% -negative_elbo = jit(gpx.CollapsedELBO(negative=True)) +negative_elbo = jit(gpx.CollapsedELBO(negative=True).step) # %timeit negative_elbo(q, D).block_until_ready() # %% [markdown] diff --git a/docs/examples/graph_kernels.py b/docs/examples/graph_kernels.py index 82154b3a..2d77d79e 100644 --- a/docs/examples/graph_kernels.py +++ b/docs/examples/graph_kernels.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # %% [markdown] # # Graph Kernels # @@ -119,7 +120,8 @@ cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax) ) sm.set_array([]) -cbar = plt.colorbar(sm) +ax = plt.gca() +cbar = plt.colorbar(sm, ax=ax) # %% [markdown] # @@ -201,8 +203,8 @@ sm = plt.cm.ScalarMappable( cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax) ) -sm.set_array([]) -cbar = plt.colorbar(sm) +ax = plt.gca() +cbar = plt.colorbar(sm, ax=ax) # %% [markdown] #