Skip to content

Commit

Permalink
Fix join logp for multivariate RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 8, 2024
1 parent 6f90f83 commit 937e5fd
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 35 deletions.
6 changes: 4 additions & 2 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,12 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
# If the stacked variables depend on each other, we have to replace them by the respective values
logps = replace_rvs_by_values(logps, rvs_to_values=base_rvs_to_split_values)

base_vars_ndim_supp = split_values[0].ndim - logps[0].ndim
# Make axis positive and adjust for multivariate logp fewer dimensions to the right
axis = pt.switch(axis >= 0, axis, value.ndim + axis)
axis = pt.minimum(axis, logps[0].ndim - 1)
join_logprob = pt.concatenate(
[pt.atleast_1d(logp) for logp in logps],
axis=axis - base_vars_ndim_supp,
axis=axis,
)

return join_logprob
Expand Down
54 changes: 21 additions & 33 deletions tests/logprob/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,34 +269,23 @@ def test_measurable_join_univariate(size1, size2, axis, concatenate):


@pytest.mark.parametrize(
"size1, supp_size1, size2, supp_size2, axis, concatenate",
"size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis",
[
(None, 2, None, 2, 0, True),
(None, 2, None, 2, -1, True),
((5,), 2, (3,), 2, 0, True),
((5,), 2, (3,), 2, -2, True),
((2,), 5, (2,), 3, 1, True),
pytest.param(
(2,),
5,
(2,),
5,
0,
False,
marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
),
pytest.param(
(2,),
5,
(2,),
5,
1,
False,
marks=pytest.mark.xfail(reason="cannot measure dimshuffled multivariate RVs"),
),
(None, 2, None, 2, 0, True, 0),
(None, 2, None, 2, -1, True, 0),
((5,), 2, (3,), 2, 0, True, 0),
((5,), 2, (3,), 2, -2, True, 0),
((2,), 5, (2,), 3, 1, True, 0),
((5, 6), 10, (5, 1), 10, 1, True, 1),
((5, 6), 10, (5, 1), 10, -2, True, 1),
((2,), 5, (2,), 5, 0, False, 0),
((2,), 5, (2,), 5, 1, False, 1),
((5, 6), 10, (5, 6), 10, 2, False, 2),
],
)
def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis, concatenate):
def test_measurable_join_multivariate(
size1, supp_size1, size2, supp_size2, axis, concatenate, logp_axis
):
base1_rv = pt.random.multivariate_normal(
np.zeros(supp_size1), np.eye(supp_size1), size=size1, name="base1"
)
Expand All @@ -310,19 +299,18 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
y_vv = y_rv.clone()

y_logp = logp(y_rv, y_vv)
assert_no_rvs(y_logp)

base_logps = [
pt.atleast_1d(logp)
for logp in conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv}).values()
]

if concatenate:
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim)
base_logps = pt.concatenate(base_logps, axis=axis_norm - 1)
expected_logp = pt.concatenate(base_logps, axis=logp_axis)
else:
axis_norm = np.core.numeric.normalize_axis_index(axis, base1_rv.ndim + 1)
base_logps = pt.stack(base_logps, axis=axis_norm - 1)
y_logp = y_logp = logp(y_rv, y_vv)
assert_no_rvs(y_logp)
expected_logp = pt.stack(base_logps, axis=logp_axis)

base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
Expand All @@ -331,7 +319,7 @@ def test_measurable_join_multivariate(size1, supp_size1, size2, supp_size2, axis
else:
y_testval = np.stack((base1_testval, base2_testval), axis=axis)
np.testing.assert_allclose(
base_logps.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
expected_logp.eval({base1_vv: base1_testval, base2_vv: base2_testval}),
y_logp.eval({y_vv: y_testval}),
)

Expand Down

0 comments on commit 937e5fd

Please sign in to comment.