Add warning about future change in hessian sign
aseyboldt authored and ricardoV94 committed May 9, 2024
1 parent 82eae9a commit 3729614
Showing 6 changed files with 56 additions and 17 deletions.
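
In user-facing terms: the Hessian helpers still return the negated Hessian of logp by default, but that default now emits a FutureWarning; passing negate_output=False opts into the future sign. A minimal sketch of the behavior (toy model, not part of the commit):

    import numpy as np
    import pymc as pm

    with pm.Model() as model:
        pm.Normal("x", 0.0, 1.0)

    point = {"x": np.array(0.0)}

    # Deprecated default: negated Hessian of logp, plus a FutureWarning.
    H_neg = pm.find_hessian(point, model=model)
    # Future-proof call: the true Hessian; negate explicitly if the old sign is needed.
    H = pm.find_hessian(point, model=model, negate_output=False)

    assert np.allclose(H_neg, -H)  # for Normal(0, 1): H == [[-1.]], H_neg == [[1.]]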
9 changes: 7 additions & 2 deletions pymc/model/core.py
@@ -658,6 +658,7 @@ def compile_d2logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
+        negate_output=True,
         **compile_kwargs,
     ) -> PointFunc:
         """Compiled log probability density hessian function.
@@ -670,7 +671,10 @@ def compile_d2logp(
         jacobian : bool
             Whether to include jacobian terms in logprob graph. Defaults to True.
         """
-        return self.compile_fn(self.d2logp(vars=vars, jacobian=jacobian), **compile_kwargs)
+        return self.model.compile_fn(
+            self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
+            **compile_kwargs,
+        )
 
     def logp(
         self,
@@ -794,6 +798,7 @@ def d2logp(
         self,
         vars: Variable | Sequence[Variable] | None = None,
         jacobian: bool = True,
+        negate_output=True,
     ) -> Variable:
         """Hessian of the model's log-probability w.r.t. ``vars``.
@@ -827,7 +832,7 @@ def d2logp(
 
         cost = self.logp(jacobian=jacobian)
         cost = rewrite_pregrad(cost)
-        return hessian(cost, value_vars)
+        return hessian(cost, value_vars, negate_output=negate_output)
 
     @property
     def datalogp(self) -> Variable:
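
With the keyword threaded through d2logp into compile_d2logp, the negation can be disabled at the model level. A minimal sketch (toy model, not part of the commit):

    import numpy as np
    import pymc as pm

    with pm.Model() as m:
        pm.Normal("x", 0.0, 1.0)

    d2 = m.compile_d2logp(negate_output=False)  # no FutureWarning
    # The true Hessian of logp for a standard normal is -1.
    assert np.allclose(d2({"x": np.array(0.0)}), -1.0)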
26 changes: 22 additions & 4 deletions pymc/pytensorf.py
@@ -352,8 +352,17 @@ def grad_ii(i, f, x):
 
 
 @pytensor.config.change_flags(compute_test_value="ignore")
-def hessian(f, vars=None):
-    return -jacobian(gradient(f, vars), vars)
+def hessian(f, vars=None, negate_output=True):
+    res = jacobian(gradient(f, vars), vars)
+    if negate_output:
+        warnings.warn(
+            "hessian will stop negating the output in a future version of PyMC.\n"
+            "To suppress this warning set `negate_output=False`",
+            FutureWarning,
+            stacklevel=2,
+        )
+        res = -res
+    return res
 
 
 @pytensor.config.change_flags(compute_test_value="ignore")
@@ -368,12 +377,21 @@ def hess_ii(i):
 
 
 @pytensor.config.change_flags(compute_test_value="ignore")
-def hessian_diag(f, vars=None):
+def hessian_diag(f, vars=None, negate_output=True):
     if vars is None:
         vars = cont_inputs(f)
 
     if vars:
-        return -pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
+        res = pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
+        if negate_output:
+            warnings.warn(
+                "hessian_diag will stop negating the output in a future version of PyMC.\n"
+                "To suppress this warning set `negate_output=False`",
+                FutureWarning,
+                stacklevel=2,
+            )
+            res = -res
+        return res
     else:
         return empty_gradient
 
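
Concretely, for f(x) = sum(x**2) the true Hessian is 2*I, so the deprecated default returns -2*I. A minimal sketch (not part of the commit):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pymc.pytensorf import hessian

    x = pt.vector("x")
    f = (x**2).sum()

    H_old = pytensor.function([x], hessian(f, vars=[x]))                       # emits FutureWarning
    H_new = pytensor.function([x], hessian(f, vars=[x], negate_output=False))  # silent

    xval = np.zeros(3)
    assert np.allclose(H_new(xval), 2 * np.eye(3))
    assert np.allclose(H_old(xval), -2 * np.eye(3))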
2 changes: 1 addition & 1 deletion pymc/sampling/mcmc.py
@@ -1461,7 +1461,7 @@ def init_nuts(
         potential = quadpotential.QuadPotentialDiag(cov)
     elif init == "map":
         start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
-        cov = pm.find_hessian(point=start)
+        cov = -pm.find_hessian(point=start, negate_output=False)
         initial_points = [start] * chains
         potential = quadpotential.QuadPotentialFull(cov)
     elif init == "adapt_full":
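
Why the extra minus sign: with negate_output=False, find_hessian returns the true Hessian of logp, so negating it reproduces the old value exactly while avoiding the FutureWarning. A sketch of the identity, assuming it runs where start and the model context from the surrounding code are available:

    import warnings
    import numpy as np
    import pymc as pm

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        old = pm.find_hessian(point=start)                    # deprecated default
    new = -pm.find_hessian(point=start, negate_output=False)  # same values, no warning
    assert np.allclose(old, new)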
10 changes: 5 additions & 5 deletions pymc/tuning/scaling.py
@@ -43,7 +43,7 @@ def fixed_hessian(point, model=None):
     return rval
 
 
-def find_hessian(point, vars=None, model=None):
+def find_hessian(point, vars=None, model=None, negate_output=True):
     """
     Returns Hessian of logp at the point passed.
@@ -55,11 +55,11 @@ def find_hessian(point, vars=None, model=None):
         Variables for which Hessian is to be calculated.
     """
     model = modelcontext(model)
-    H = model.compile_d2logp(vars)
+    H = model.compile_d2logp(vars, negate_output=negate_output)
     return H(Point(point, filter_model_vars=True, model=model))
 
 
-def find_hessian_diag(point, vars=None, model=None):
+def find_hessian_diag(point, vars=None, model=None, negate_output=True):
     """
     Returns Hessian of logp at the point passed.
@@ -71,14 +71,14 @@ def find_hessian_diag(point, vars=None, model=None):
         Variables for which Hessian is to be calculated.
     """
     model = modelcontext(model)
-    H = model.compile_fn(hessian_diag(model.logp(), vars))
+    H = model.compile_fn(hessian_diag(model.logp(), vars, negate_output=negate_output))
     return H(Point(point, model=model))
 
 
 def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
     model = modelcontext(model)
     try:
-        h = find_hessian_diag(point, vars, model=model)
+        h = -find_hessian_diag(point, vars, model=model, negate_output=False)
     except NotImplementedError:
         h = fixed_hessian(point, model=model)
     return adjust_scaling(h, scaling_bound)
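
guess_scaling keeps its old numbers the same way: request the un-negated Hessian diagonal and apply the sign explicitly. A minimal sketch of the updated helper (toy model, not part of the commit):

    import numpy as np
    import pymc as pm
    from pymc.tuning.scaling import find_hessian_diag

    with pm.Model() as m:
        pm.Normal("x", 0.0, 1.0, shape=2)

    h = -find_hessian_diag({"x": np.zeros(2)}, model=m, negate_output=False)
    assert np.allclose(h, np.ones(2))  # d2logp/dx2 = -1 per standard-normal component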
8 changes: 4 additions & 4 deletions tests/model/test_core.py
@@ -1012,16 +1012,16 @@ def test_model_d2logp(jacobian):
     test_vals = np.array([0.0, -1.0])
     state = {"x": test_vals, "y_log__": test_vals}
 
-    expected_x_d2logp = expected_y_d2logp = np.eye(2)
+    expected_x_d2logp = expected_y_d2logp = -np.eye(2)
 
-    dlogps = m.compile_d2logp(jacobian=jacobian)(state)
+    dlogps = m.compile_d2logp(jacobian=jacobian, negate_output=False)(state)
     assert np.all(np.isclose(dlogps[:2, :2], expected_x_d2logp))
     assert np.all(np.isclose(dlogps[2:, 2:], expected_y_d2logp))
 
-    x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian)(state)
+    x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian, negate_output=False)(state)
     assert np.all(np.isclose(x_dlogp2, expected_x_d2logp))
 
-    y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian)(state)
+    y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian, negate_output=False)(state)
     assert np.all(np.isclose(y_dlogp2, expected_y_d2logp))
 
 
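The flipped expectations are the point of the change: with negate_output=False the test receives the raw Hessian, and the logp here is quadratic with unit curvature, so each diagonal block is -np.eye(2); the old negated default produced +np.eye(2).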
18 changes: 17 additions & 1 deletion tests/test_pytensorf.py
@@ -25,7 +25,7 @@
 from pytensor import scan, shared
 from pytensor.compile import UnusedInputError
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph.basic import Variable
+from pytensor.graph.basic import Variable, equal_computations
 from pytensor.tensor.random.basic import normal, uniform
 from pytensor.tensor.random.var import RandomStateSharedVariable
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -43,6 +43,8 @@
     constant_fold,
     convert_observed_data,
     extract_obs_data,
+    hessian,
+    hessian_diag,
     replace_rng_nodes,
     replace_vars_in_graphs,
     reseed_rngs,
@@ -726,3 +728,17 @@ def test_replace_vars_in_graphs_nested_reference():
     assert np.abs(x.eval()) < 1
     # Confirm the original `y` variable is not changed in place
     assert np.abs(y.eval()) < 1
+
+
+@pytest.mark.filterwarnings("error")
+@pytest.mark.parametrize("func", (hessian, hessian_diag))
+def test_hessian_sign_change_warning(func):
+    x = pt.vector("x")
+    f = (x**2).sum()
+    with pytest.warns(
+        FutureWarning,
+        match="will stop negating the output",
+    ):
+        res_neg = func(f, vars=[x])
+    res = func(f, vars=[x], negate_output=False)
+    assert equal_computations([res_neg], [-res])
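
A note on the test design: equal_computations compares the symbolic graphs, so the assertion proves that res_neg is exactly the negation of res without compiling or evaluating anything. The filterwarnings("error") mark turns any unexpected warning, in particular one from the negate_output=False call, into a test failure.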
