Skip to content

Commit

Permalink
Merge pull request #154 from JaxGaussianProcesses/incorporate_jax_lin…
Browse files Browse the repository at this point in the history
…ear_operator

Incorporate JaxLinOp with GPJax
  • Loading branch information
thomaspinder authored Nov 29, 2022
2 parents 23e30a8 + b41f7e4 commit dbaa5ce
Show file tree
Hide file tree
Showing 12 changed files with 604 additions and 863 deletions.
3 changes: 1 addition & 2 deletions examples/natgrads.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import jax.random as jr
import matplotlib.pyplot as plt
import optax as ox
from jax import jit, lax
from jax.config import config

import gpjax as gpx
Expand Down Expand Up @@ -97,7 +96,7 @@
n_iters=5000,
batch_size=256,
key=jr.PRNGKey(42),
moment_optim=ox.sgd(0.1),
moment_optim=ox.sgd(0.01),
hyper_optim=ox.adam(1e-3),
)

Expand Down
Loading

0 comments on commit dbaa5ce

Please sign in to comment.