diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 036c9ffe..d33755b3 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -153,7 +153,7 @@ def q_moments(x): log_prob = vmap(lambda f, y: link_function(params["likelihood"], f).log_prob(y)) # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) - expectation = gauss_hermite_quadrature(log_prob, mean, variance, y=y) + expectation = gauss_hermite_quadrature(log_prob, mean, jnp.sqrt(variance), y=y) return expectation