Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Smoke testing #307

Merged
merged 8 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Integration Tests
on:
pull_request:
push:
branches:
- main

jobs:
test:
name: Run Integration Tests
runs-on: ubuntu-latest
strategy:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: true
steps:
- name: Check out the code
uses: actions/checkout@v3.5.2
with:
fetch-depth: 1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

# Install Poetry
- name: Install Poetry
uses: snok/install-poetry@v1.3.3
with:
version: 1.4.0

# Configure Poetry to use the virtual environment in the project
- name: Setup Poetry
run: |
poetry config virtualenvs.in-project true

# Install the dependencies
- name: Install Package
run: |
poetry install --all-extras --with docs

# Run the unit tests and build the coverage report
- name: Run Integration Tests
run: poetry run python tests/integration_tests.py
45 changes: 35 additions & 10 deletions docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax_beartype
# language: python
# name: python3
# ---

# %% [markdown]
# # Sparse Gaussian Process Regression
#
Expand Down Expand Up @@ -71,12 +87,16 @@
fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.25, label="Observations", color=cols[0])
ax.plot(xtest, ytest, label="Latent function", linewidth=2, color=cols[1])
[
ax.axvline(x=z_i, alpha=0.3, linewidth=0.5, label="Inducing point", color=cols[2])
for z_i in z
]
ax.vlines(
x=z,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=0.5,
label="Inducing point",
color=cols[2],
)
ax.legend(loc="best")
ax = clean_legend(ax)
plt.show()

# %% [markdown]
Expand Down Expand Up @@ -195,12 +215,17 @@
linewidth=0.5,
)

[
ax.axvline(x=z_i, alpha=0.3, linewidth=0.5, label="Inducing point", color=cols[2])
for z_i in inducing_points
]

ax.vlines(
x=inducing_points,
ymin=ytest.min(),
ymax=ytest.max(),
alpha=0.3,
linewidth=0.5,
label="Inducing point",
color=cols[2],
)
ax.legend()
ax = clean_legend(ax)
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")
plt.show()

Expand Down
44 changes: 27 additions & 17 deletions docs/examples/uncollapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax as tfp
from docs.examples.utils import clean_legend

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
Expand Down Expand Up @@ -132,14 +131,18 @@
z = jnp.linspace(-5.0, 5.0, 50).reshape(-1, 1)

fig, ax = plt.subplots()
[
ax.axvline(x=z_i, color=cols[2], alpha=0.3, linewidth=1, label="Inducing point")
for z_i in z
]
ax.vlines(
z,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.scatter(x, y, alpha=0.2, color=cols[0], label="Observations")
ax.plot(xtest, f(xtest), color=cols[1], label="Latent function")
ax.legend()
ax = clean_legend(ax)
ax.set(xlabel=r"$x$", ylabel=r"$f(x)$")

# %% [markdown]
Expand Down Expand Up @@ -241,7 +244,6 @@
# advisable. This can be achieved by wrapping the function in `jax.jit()`.

# %%

negative_elbo = jit(negative_elbo)

# %% [markdown]
Expand Down Expand Up @@ -296,12 +298,16 @@
color=cols[1],
label="Two sigma",
)
[
ax.axvline(x=z_i, color=cols[2], alpha=0.3, linewidth=1, label="Inducing point")
for z_i in opt_posterior.inducing_inputs
]
ax.vlines(
opt_posterior.inducing_inputs,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.legend()
ax = clean_legend(ax)

# %% [markdown]
# ## Custom transformations
Expand Down Expand Up @@ -349,12 +355,16 @@
color=cols[1],
label="Two sigma",
)
[
ax.axvline(x=z_i, color=cols[2], alpha=0.3, linewidth=1, label="Inducing point")
for z_i in opt_posterior.inducing_inputs
]
ax.vlines(
opt_rep.inducing_inputs,
ymin=y.min(),
ymax=y.max(),
alpha=0.3,
linewidth=1,
label="Inducing point",
color=cols[2],
)
ax.legend()
ax = clean_legend(ax)

# %% [markdown]
# We can see that `Square` transformation is able to get relatively better fit
Expand Down
93 changes: 93 additions & 0 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from dataclasses import (
dataclass,
field,
)

from beartype.typing import (
Any,
Callable,
Dict,
)
import jax.numpy as jnp # noqa: F401
import jupytext

get_last = lambda x: x[-1]


@dataclass
class Result:
path: str
comparisons: field(default_factory=dict)
precision: int = 5

def __post_init__(self):
self.name: str = self.path.split("/")[-1].split(".")[0].replace("_", "-")

def _compare(
self,
observed_variables: Dict[str, Any],
variable_name: str,
true_value: float,
operation: Callable[[Any], Any],
):
try:
value = operation(observed_variables[variable_name])
assert true_value == value
except AssertionError as e:
print(e)

def test(self):
notebook = jupytext.read(self.path)
contents = ""
for c in notebook["cells"]:
if c["cell_type"] == "code":
if c["source"].startswith("%"):
pass
else:
contents += c["source"]
contents += "\n"

contents = contents.replace('plt.style.use("./gpjax.mplstyle")', "").replace(
"plt.show()", ""
)
lines = contents.split("\n")
contents = "\n".join([line for line in lines if not line.startswith("%")])

loc = {}
exec(contents, globals(), loc)
for k, v in self.comparisons.items():
truth, op = v
self._compare(
observed_variables=loc, variable_name=k, true_value=truth, operation=op
)


regression = Result(
path="docs/examples/regression.py",
comparisons={
"history": (55.07405622, get_last),
"predictive_mean": (36.24383416, jnp.sum),
"predictive_std": (197.04727051, jnp.sum),
},
)
regression.test()

sparse = Result(
path="docs/examples/collapsed_vi.py",
comparisons={
"history": (1924.7634809, get_last),
"predictive_mean": (-8.39869652, jnp.sum),
"predictive_std": (255.74838027, jnp.sum),
},
)
sparse.test()

stochastic = Result(
path="docs/examples/uncollapsed_vi.py",
comparisons={
"history": (-985.71726453, get_last),
"meanf": (-54.14787028, jnp.sum),
"sigma": (116.16651265, jnp.sum),
},
)
stochastic.test()