Skip to content

Commit

Permalink
lint + format all doc examples
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane committed Mar 8, 2024
1 parent 8782915 commit d52530e
Show file tree
Hide file tree
Showing 11 changed files with 14 additions and 23 deletions.
3 changes: 1 addition & 2 deletions docs/examples/barycentres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import jax.scipy.linalg as jsl
from jaxtyping import install_import_hook
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax.distributions as tfd

with install_import_hook("gpjax", "beartype.beartype"):
Expand Down Expand Up @@ -102,7 +101,7 @@
f = lambda x, a, b: a + jnp.sin(b * x)

ys = []
for _i in range(n_datasets):
for _ in range(n_datasets):
key, subkey = jr.split(key)
vertical_shift = jr.uniform(subkey, minval=0.0, maxval=2.0)
period = jr.uniform(subkey, minval=0.75, maxval=1.25)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def plot_bayes_opt(
initial_y = standardised_forrester(initial_x)
D = gpx.Dataset(X=initial_x, y=initial_y)

for i in range(bo_iters):
for _ in range(bo_iters):
key, subkey = jr.split(key)

# Generate optimised posterior using previously observed data
Expand Down
1 change: 0 additions & 1 deletion docs/examples/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
)
import matplotlib.pyplot as plt
import optax as ox
from flax.experimental import nnx
import tensorflow_probability.substrates.jax as tfp
from tqdm import trange

Expand Down
4 changes: 0 additions & 4 deletions docs/examples/constructing_new_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
config.update("jax_enable_x64", True)

from dataclasses import dataclass
from typing import Dict

from jax import jit
import jax.numpy as jnp
Expand All @@ -39,13 +38,10 @@
install_import_hook,
)
import matplotlib.pyplot as plt
import numpy as np
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
from gpjax.base.param import param_field

key = jr.PRNGKey(123)
tfb = tfp.bijectors
Expand Down
5 changes: 0 additions & 5 deletions docs/examples/deep_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,10 @@
import matplotlib.pyplot as plt
import optax as ox
from scipy.signal import sawtooth
from gpjax.base import static_field

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
from gpjax.base import param_field
import gpjax.kernels as jk
from gpjax.kernels import DenseKernelComputation
from gpjax.kernels.base import AbstractKernel
from gpjax.kernels.computations import AbstractKernelComputation

key = jr.PRNGKey(123)
plt.style.use(
Expand Down
2 changes: 0 additions & 2 deletions docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@

import random

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
Expand Down
14 changes: 10 additions & 4 deletions docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
Expand Down Expand Up @@ -197,7 +196,7 @@

# %%
# Forrester function
def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]: # noqa: F821
return (6 * x - 2) ** 2 * jnp.sin(12 * x - 4)


Expand Down Expand Up @@ -249,17 +248,24 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# with the optimised hyperparameters, and compare them to the predictions made using the
# posterior with the default hyperparameters:


# %%
def plot_ribbon(ax, x, dist, color):
mean = dist.mean()
std = dist.stddev()
ax.plot(x, mean, label="Predictive mean", color=color)
ax.fill_between(x.squeeze(), mean - 2 * std, mean + 2 * std, alpha=0.2, label="Two sigma", color=color)
ax.fill_between(
x.squeeze(),
mean - 2 * std,
mean + 2 * std,
alpha=0.2,
label="Two sigma",
color=color,
)
ax.plot(x, mean - 2 * std, linestyle="--", linewidth=1, color=color)
ax.plot(x, mean + 2 * std, linestyle="--", linewidth=1, color=color)



# %%
opt_latent_dist = opt_posterior.predict(test_x, train_data=D)
opt_predictive_dist = opt_posterior.likelihood(opt_latent_dist)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def one_step(state, rng_key):
samples = []

for i in range(num_adapt, num_samples + num_adapt, thin_factor):
sample = jtu.tree_map(lambda samples: samples[i], states.position)
sample = jtu.tree_map(lambda samples: samples[i], states.position) # noqa: B023
sample = sample.constrain()
latent_dist = sample.predict(xtest, train_data=D)
predictive_dist = sample.likelihood(latent_dist)
Expand Down
1 change: 0 additions & 1 deletion docs/examples/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from jaxtyping import install_import_hook
import matplotlib as mpl
import matplotlib.pyplot as plt
import optax as ox
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
Expand Down
1 change: 0 additions & 1 deletion docs/examples/uncollapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook
Expand Down
2 changes: 1 addition & 1 deletion docs/scripts/sharp_bits_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@
np.log(0.05)

# %%
x
print(x)

0 comments on commit d52530e

Please sign in to comment.