Perhaps finally a decent LBFGS? #426

Merged · 4 commits · Dec 3, 2023
7 changes: 2 additions & 5 deletions docs/examples/graph_kernels.py
@@ -154,14 +154,11 @@
 # With a posterior defined, we can now optimise the model's hyperparameters.

 # %%
-opt_posterior, training_history = gpx.fit(
+opt_posterior, training_history = gpx.fit_scipy(
     model=posterior,
     objective=gpx.objectives.ConjugateMLL(negative=True),
     train_data=D,
-    optim=ox.adam(learning_rate=0.01),
-    num_iters=1000,
-    key=key
-)
+)

 # %% [markdown]
 #
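For orientation: the example no longer passes an optax optimiser, an iteration budget, or a PRNG key, because fit_scipy delegates the whole optimisation loop to SciPy. As the fit.py diff below shows, the returned training_history now holds the initial loss followed by one entry per optimiser iteration, so convergence can be inspected directly. A two-line illustration (not part of the example file):

# Illustrative only: first entry is the initial loss, later entries come
# from the SciPy callback, one per iteration.
print(
    f"loss {float(training_history[0]):.3f} -> {float(training_history[-1]):.3f} "
    f"in {len(training_history) - 1} iterations"
)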
54 changes: 28 additions & 26 deletions gpjax/fit.py
@@ -23,11 +23,16 @@
     Union,
 )
 import jax
+from jax import (
+    jit,
+    value_and_grad,
+)
 from jax._src.random import _check_prng_key
+from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
-import jaxopt
 import optax as ox
+import scipy

 from gpjax.base import Module
 from gpjax.dataset import Dataset

@@ -42,10 +47,6 @@
 ModuleModel = TypeVar("ModuleModel", bound=Module)


-class FailedScipyFitError(Exception):
-    """Raised a model fit using Scipy fails"""
-
-
 def fit(  # noqa: PLR0913
     *,
     model: ModuleModel,

@@ -214,30 +215,31 @@ def fit_scipy(  # noqa: PLR0913
     model = model.unconstrain()

     # Unconstrained space loss function with stop-gradient rule for non-trainable params.
-    def loss(model: Module, data: Dataset) -> ScalarFloat:
+    def loss(model: Module) -> ScalarFloat:
         model = model.stop_gradient()
-        return objective(model.constrain(), data)
+        return objective(model.constrain(), train_data)

-    solver = jaxopt.ScipyMinimize(
-        fun=loss,
-        maxiter=max_iters,
-    )
-
-    initial_loss = solver.fun(model, train_data)
-    model, result = solver.run(model, data=train_data)
-    history = jnp.array([initial_loss, result.fun_val])
-
-    if verbose:
-        print(f"Initial loss is {initial_loss}")
-        if result.success:
-            print("Optimization was successful")
-        else:
-            raise FailedScipyFitError(
-                "Optimization failed, try increasing max_iters or using a different optimiser."
-            )
-        print(f"Final loss is {result.fun_val} after {result.num_fun_eval} iterations")
-
-    # Constrained space.
+    # convert to numpy for interface with scipy
+    x0, scipy_to_jnp = ravel_pytree(model)
+
+    @jit
+    def scipy_wrapper(x0):
+        value, grads = value_and_grad(loss)(scipy_to_jnp(jnp.array(x0)))
+        scipy_grads = ravel_pytree(grads)[0]
+        return value, scipy_grads
+
+    history = [scipy_wrapper(x0)[0]]
+    result = scipy.optimize.minimize(
+        fun=scipy_wrapper,
+        x0=x0,
+        jac=True,
+        callback=lambda X: history.append(scipy_wrapper(X)[0]),
+        options={"maxiter": max_iters, "disp": verbose},
+    )
+    history = jnp.array(history)
+
+    # convert back to pytree and reconstrain
+    model = scipy_to_jnp(result.x)
     model = model.constrain()
     return model, history

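The core trick in the new fit_scipy is the ravel_pytree round trip: SciPy wants a flat vector and a (value, gradient) pair, while the model's parameters live in a pytree. A self-contained sketch of the same pattern on a toy objective follows; every name in it is illustrative, and method is pinned to "L-BFGS-B" here for concreteness, whereas the code above leaves SciPy to pick its default:

import jax.numpy as jnp
import scipy.optimize
from jax import jit, value_and_grad
from jax.flatten_util import ravel_pytree

# Toy "model": a parameter pytree and a loss standing in for the negative MLL.
params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}

def loss(p):
    return (p["lengthscale"] - 0.5) ** 2 + (p["variance"] - 1.5) ** 2

# Flatten the pytree into one vector SciPy can consume; keep the inverse map.
x0, unravel = ravel_pytree(params)

@jit
def scipy_fn(x):
    # One compiled call returns both the loss and its flattened gradient.
    value, grads = value_and_grad(loss)(unravel(jnp.asarray(x)))
    return value, ravel_pytree(grads)[0]

history = [float(scipy_fn(x0)[0])]
result = scipy.optimize.minimize(
    fun=scipy_fn,
    x0=x0,
    jac=True,  # fun returns (value, gradient) in a single call
    method="L-BFGS-B",
    callback=lambda x: history.append(float(scipy_fn(x)[0])),
)
fitted = unravel(result.x)  # back to the original pytree structure

Returning the gradient alongside the value via jac=True avoids a second evaluation per step, which is why the diff wraps value_and_grad rather than two separate functions.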
1 change: 0 additions & 1 deletion pyproject.toml
@@ -27,7 +27,6 @@ jax = ">=0.4.10"
 jaxlib = ">=0.4.10"
 orbax-checkpoint = ">=0.2.3"
 cola-ml = "^0.0.5"
-jaxopt = "^0.8.2"

 [tool.poetry.group.test.dependencies]
 pytest = "^7.2.2"
10 changes: 5 additions & 5 deletions tests/test_fit.py
@@ -25,6 +25,7 @@
 )
 import optax as ox
 import pytest
+import scipy
 import tensorflow_probability.substrates.jax.bijectors as tfb

 from gpjax.base import (

@@ -33,7 +34,6 @@
 )
 from gpjax.dataset import Dataset
 from gpjax.fit import (
-    FailedScipyFitError,
     fit,
     fit_scipy,
     get_batch,
@@ -116,7 +116,7 @@ def step(self, model: LinearModel, train_data: Dataset) -> float:
     )

     # Ensure we return a history of the correct length
-    assert len(hist) == 2
+    assert len(hist) > 2

     # Ensure we return a model of the same class
     assert isinstance(trained_model, LinearModel)
@@ -180,7 +180,7 @@ def test_gaussian_process_regression(n_data: int, verbose: bool) -> None:
     assert isinstance(trained_model_bfgs, ConjugatePosterior)

     # Ensure we return a history_bfgs of the correct length
-    assert len(history_bfgs) == 2
+    assert len(history_bfgs) > 2

     # Ensure we reduce the loss
     assert mll(trained_model_bfgs, D) < mll(posterior, D)
@@ -206,7 +206,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
     # Define loss function:
     mll = ConjugateMLL(negative=True)

-    with pytest.raises(FailedScipyFitError):
+    with pytest.raises(scipy.optimize.OptimizeWarning):
         fit_scipy(
             model=posterior,
             objective=mll,
@@ -220,7 +220,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
     posterior = prior * likelihood
     mll = ConjugateMLL(negative=True)

-    with pytest.raises(FailedScipyFitError):
+    with pytest.raises(scipy.optimize.OptimizeWarning):
         fit_scipy(
             model=posterior,
             objective=mll,
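One subtlety in the updated tests: pytest.raises only intercepts scipy.optimize.OptimizeWarning if warnings are escalated to errors, presumably via a suite-wide filterwarnings = "error" setting in the project's pytest configuration (an assumption; that config is not part of this diff). A minimal sketch of the mechanism, with warnings.warn standing in for a failing fit_scipy call:

import warnings

import pytest
import scipy.optimize


def test_optimize_warning_escalates() -> None:
    # Escalate OptimizeWarning to an exception locally so pytest.raises can
    # catch it, mirroring a suite-wide filterwarnings = "error" setting.
    with warnings.catch_warnings():
        warnings.simplefilter("error", scipy.optimize.OptimizeWarning)
        with pytest.raises(scipy.optimize.OptimizeWarning):
            warnings.warn("stand-in for a failed fit", scipy.optimize.OptimizeWarning)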