Implement specialized transformed logp dispatch #7188

Status: Draft · wants to merge 1 commit into base: main
14 changes: 11 additions & 3 deletions pymc/distributions/multivariate.py
@@ -69,7 +69,7 @@
    to_tuple,
)
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
-from pymc.logprob.abstract import _logprob
+from pymc.logprob.abstract import _logprob, _transformed_logprob
from pymc.math import kron_diag, kron_dot
from pymc.pytensorf import intX
from pymc.util import check_dist_not_registered
@@ -2818,8 +2818,7 @@ def zerosumnormal_support_point(op, rv, *rv_inputs):

@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
-    n_zerosum_axes = tuple(np.arange(-op.ndim_supp, 0))
-    return ZeroSumTransform(n_zerosum_axes)
+    return ZeroSumTransform(n_zerosum_axes=op.ndim_supp)


@_logprob.register(ZeroSumNormalRV)
@@ -2845,3 +2844,12 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    )

    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")


+@_transformed_logprob.register(ZeroSumNormalRV, ZeroSumTransform)
+def transformed_zerosumnormal_logp(op, transform, unconstrained_value, rv_inputs):
+    _, sigma, _ = rv_inputs
+    zerosum_axes = transform.zerosum_axes
+    if len(zerosum_axes) != op.ndim_supp:
+        raise NotImplementedError
+    return pm.logp(Normal.dist(0, sigma), unconstrained_value).sum(zerosum_axes)

Review thread on the `if len(zerosum_axes) != op.ndim_supp:` line:

Member: I don't think this check is necessary?

Member Author: It was a sanity check. A user could have defined an invalid transform manually. I don't really care either way.
11 changes: 6 additions & 5 deletions pymc/distributions/transforms.py
@@ -25,18 +25,17 @@
from pytensor.graph import Op
from pytensor.tensor import TensorVariable

+from pymc.logprob.abstract import Transform
from pymc.logprob.transforms import (
    ChainedTransform,
    CircularTransform,
    IntervalTransform,
    LogOddsTransform,
    LogTransform,
    SimplexTransform,
-    Transform,
)

__all__ = [
    "Transform",
    "simplex",
    "logodds",
    "Interval",
@@ -277,8 +276,10 @@

    __props__ = ("zerosum_axes",)

-    def __init__(self, zerosum_axes):
-        self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)
+    def __init__(self, n_zerosum_axes: int):
+        if not n_zerosum_axes > 0:
+            raise ValueError("Transform is only valid for n_zerosum_axes > 0")
+
+        self.zerosum_axes = tuple(range(-n_zerosum_axes, 0))

[Codecov: the added `raise ValueError` line (transforms.py line 281) was not covered by tests.]

    @staticmethod
    def extend_axis(array, axis):
@@ -314,7 +315,7 @@
        return value

    def log_jac_det(self, value, *rv_inputs):
-        return pt.constant(0.0)
+        return pt.zeros(value.shape[: -len(self.zerosum_axes)])

[Codecov: the added `return pt.zeros(...)` line (transforms.py line 318) was not covered by tests.]


log_exp_m1 = LogExpM1()
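A quick sketch of the changed `log_jac_det` behavior (runnable only against this branch, since it assumes the new `n_zerosum_axes` constructor and the batched zeros):

import pytensor.tensor as pt

from pymc.distributions.transforms import ZeroSumTransform

# One zero-sum axis over the last dimension of a hypothetical (5, 3) value
transform = ZeroSumTransform(n_zerosum_axes=1)
value = pt.zeros((5, 3))

# The log-Jacobian is still identically zero, but it now keeps the batch
# shape (5,) instead of collapsing to a scalar, so it broadcasts correctly
# against a logp that retains batch dimensions
assert transform.log_jac_det(value).eval().shape == (5,)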
2 changes: 1 addition & 1 deletion pymc/initial_point.py
@@ -25,7 +25,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.variable import TensorVariable

-from pymc.logprob.transforms import Transform
+from pymc.logprob.abstract import Transform
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

56 changes: 55 additions & 1 deletion pymc/logprob/abstract.py
@@ -38,10 +38,15 @@

from collections.abc import Sequence
from functools import singledispatch
+from typing import Union

+import multipledispatch
+import pytensor.tensor as pt
+
+from pytensor.gradient import jacobian
from pytensor.graph.op import Op
from pytensor.graph.utils import MetaType
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorVariable, Variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable

@@ -153,3 +158,52 @@ def __init__(self, scalar_op, *args, **kwargs):


MeasurableVariable.register(MeasurableElemwise)


Member Author (review comment): Moved to abstract, without any changes.

+class Transform(abc.ABC):
+    ndim_supp = None
+
+    @abc.abstractmethod
+    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
+        """Apply the transformation."""
+
+    @abc.abstractmethod
+    def backward(
+        self, value: TensorVariable, *inputs: Variable
+    ) -> Union[TensorVariable, tuple[TensorVariable, ...]]:
+        """Invert the transformation. Multiple values may be returned when the
+        transformation is not 1-to-1"""
+
+    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
+        """Construct the log of the absolute value of the Jacobian determinant."""
+        if self.ndim_supp not in (0, 1):
+            raise NotImplementedError(
+                f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}"
+            )
+        if self.ndim_supp == 0:
+            jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape)
+            return pt.log(pt.abs(jac))
+        else:
+            phi_inv = self.backward(value, *inputs)
+            return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
+
+    def __str__(self):
+        return f"{self.__class__.__name__}"


+@multipledispatch.dispatch(Op, Transform)
+def _transformed_logprob(
+    op: Op,
+    transform: Transform,
+    unconstrained_value: TensorVariable,
+    rv_inputs: Sequence[TensorVariable],
+):
+    """Create a graph for the log-density/mass of a transformed ``RandomVariable``.
+
+    This function dispatches on the type of ``op``, which should be a subclass
+    of ``RandomVariable``, and ``transform``, which should be a subclass of ``Transform``.
+
+    """
+    raise NotImplementedError(
+        f"Transformed logprob method not implemented for {op} with transform {transform}"
+    )
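For illustration only (this pairing is not part of the diff, and the names below are assumptions): any distribution whose default transform yields a closed-form density on the unconstrained side could register a specialization the same way ZeroSumNormal does above. A log-normal under a log transform is the textbook case, since log(x) for x ~ LogNormal(mu, sigma) is exactly Normal(mu, sigma), with no Jacobian term needed:

import pymc as pm

from pytensor.tensor.random.basic import LogNormalRV

from pymc.logprob.abstract import _transformed_logprob
from pymc.logprob.transforms import LogTransform


@_transformed_logprob.register(LogNormalRV, LogTransform)
def transformed_lognormal_logp(op, transform, unconstrained_value, rv_inputs):
    # Distribution parameters are the trailing RV inputs (rng/size come first)
    *_, mu, sigma = rv_inputs
    # u = log(x) is Normal(mu, sigma) when x ~ LogNormal(mu, sigma)
    return pm.logp(pm.Normal.dist(mu, sigma), unconstrained_value)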
2 changes: 1 addition & 1 deletion pymc/logprob/basic.py
@@ -58,14 +58,14 @@

from pymc.logprob.abstract import (
    MeasurableVariable,
+    Transform,
    _icdf_helper,
    _logcdf_helper,
    _logprob,
    _logprob_helper,
)
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
-from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

45 changes: 36 additions & 9 deletions pymc/logprob/transform_value.py
@@ -23,12 +23,12 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
from pytensor.scan.op import Scan
from pytensor.tensor.variable import TensorVariable

-from pymc.logprob.abstract import MeasurableVariable, _logprob
+from pymc.logprob.abstract import MeasurableVariable, Transform, _logprob, _transformed_logprob
from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db
-from pymc.logprob.transforms import Transform


class TransformedValue(Op):
@@ -97,7 +97,26 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs):
    This is introduced by the `TransformValuesRewrite`
    """
    rv_op = rv_outs[0].owner.op
+    transforms = op.transforms
    rv_inputs = rv_outs[0].owner.inputs

+    if use_jacobian and len(values) == 1 and len(transforms) == 1:
+        # Check if there's a specialized transform logp implemented
+        [value] = values
+        assert isinstance(value.owner.op, TransformedValue)
+        unconstrained_value = value.owner.inputs[1]
+        [transform] = transforms
+        try:
+            return _transformed_logprob(
+                rv_op,
+                transform,
+                unconstrained_value=unconstrained_value,
+                rv_inputs=rv_inputs,
+                **kwargs,
+            )
+        except NotImplementedError:
+            pass

Review thread on the `unconstrained_value=unconstrained_value,` line:

Member: It would be nice if we could also pass in the transformed value. That way we can avoid computing it twice in a graph if the logp can use that value too. I guess pytensor might get rid of that duplication anyway, but I don't know how reliable that is if the transformation is doing something more complicated.

Member Author: We can pass it too.

You raise a point. Maybe transforms should be encapsulated in an OpFromGraph so that we can easily reverse symbolically and not worry whether they will be simplified in their raw form or not.

Member Author: Actually we have the constrained value as well, it's `value.owner.inputs[0]`.

    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)

    if not isinstance(logprobs, Sequence):
@@ -112,8 +131,8 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs):
            continue

        assert isinstance(value.owner.op, TransformedValue)
-        original_forward_value = value.owner.inputs[1]
-        log_jac_det = transform.log_jac_det(original_forward_value, *rv_inputs).copy()
+        unconstrained_value = value.owner.inputs[1]
+        log_jac_det = transform.log_jac_det(unconstrained_value, *rv_inputs).copy()
        # The jacobian determinant has less dims than the logp
        # when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions.
        # In this case we have to reduce the last logp dimensions, as they are no longer independent
@@ -299,6 +318,17 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[Apply]]:
    return transformed_rv_node.outputs


+transform_values_rewrites_db = SequenceDB()
+transform_values_rewrites_db.name = "transform_values_rewrites_db"
+
+transform_values_rewrites_db.register(
+    "transform_values", in2out(transform_values, ignore_newtrees=True), "basic"
+)
+transform_values_rewrites_db.register(
+    "transform_scan_values", in2out(transform_scan_values, ignore_newtrees=True), "basic"
+)


class TransformValuesMapping(Feature):
    r"""A `Feature` that maintains a map between value variables and their transforms."""

@@ -315,9 +345,6 @@ def on_attach(self, fgraph):
class TransformValuesRewrite(GraphRewriter):
    r"""Transforms value variables according to a map."""

-    transform_rewrite = in2out(transform_values, ignore_newtrees=True)
-    scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)
-
    def __init__(
        self,
        values_to_transforms: dict[TensorVariable, Union[Transform, None]],
@@ -340,8 +367,8 @@ def add_requirements(self, fgraph):
        fgraph.attach_feature(values_transforms_feature)

    def apply(self, fgraph: FunctionGraph):
-        self.transform_rewrite.rewrite(fgraph)
-        self.scan_transform_rewrite.rewrite(fgraph)
+        query = RewriteDatabaseQuery(include=["basic"])
+        transform_values_rewrites_db.query(query).rewrite(fgraph)
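Moving the two rewrites into a `SequenceDB` makes the pass queryable by tag. A sketch of what that enables (assumed usage, mirroring pytensor's rewrite-database API; not part of the diff):

from pytensor.graph.rewriting.db import RewriteDatabaseQuery

from pymc.logprob.transform_value import transform_values_rewrites_db

# Query only the "basic"-tagged rewrites, excluding the scan variant
query = RewriteDatabaseQuery(include=["basic"], exclude=["transform_scan_values"])
rewriter = transform_values_rewrites_db.query(query)
# rewriter.rewrite(fgraph) would then apply only the non-scan rewrite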


@node_rewriter([TransformedValue])
36 changes: 2 additions & 34 deletions pymc/logprob/transforms.py
@@ -33,15 +33,13 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-import abc
-
-from typing import Callable, Optional, Union
+from typing import Callable, Optional

import numpy as np
import pytensor.tensor as pt

from pytensor import scan
-from pytensor.gradient import jacobian
from pytensor.graph.basic import Node, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
@@ -109,6 +107,7 @@
from pymc.logprob.abstract import (
    MeasurableElemwise,
    MeasurableVariable,
+    Transform,
    _icdf,
    _icdf_helper,
    _logcdf,
@@ -124,37 +123,6 @@
)


-class Transform(abc.ABC):
-    ndim_supp = None
-
-    @abc.abstractmethod
-    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
-        """Apply the transformation."""
-
-    @abc.abstractmethod
-    def backward(
-        self, value: TensorVariable, *inputs: Variable
-    ) -> Union[TensorVariable, tuple[TensorVariable, ...]]:
-        """Invert the transformation. Multiple values may be returned when the
-        transformation is not 1-to-1"""
-
-    def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable:
-        """Construct the log of the absolute value of the Jacobian determinant."""
-        if self.ndim_supp not in (0, 1):
-            raise NotImplementedError(
-                f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}"
-            )
-        if self.ndim_supp == 0:
-            jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape)
-            return pt.log(pt.abs(jac))
-        else:
-            phi_inv = self.backward(value, *inputs)
-            return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0]))))
-
-    def __str__(self):
-        return f"{self.__class__.__name__}"


class MeasurableTransform(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""

6 changes: 1 addition & 5 deletions pymc/logprob/utils.py
@@ -33,7 +33,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
-import typing
import warnings

from collections.abc import Container, Sequence
@@ -56,13 +55,10 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

-from pymc.logprob.abstract import MeasurableVariable, _logprob
+from pymc.logprob.abstract import MeasurableVariable, Transform, _logprob
from pymc.pytensorf import replace_vars_in_graphs
from pymc.util import makeiter

-if typing.TYPE_CHECKING:
-    from pymc.logprob.transforms import Transform


def replace_rvs_by_values(
    graphs: Sequence[TensorVariable],
2 changes: 1 addition & 1 deletion pymc/model/fgraph.py
@@ -23,7 +23,7 @@
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise

-from pymc.logprob.transforms import Transform
+from pymc.logprob.abstract import Transform
from pymc.model.core import Model
from pymc.pytensorf import StringType, find_rng_nodes, toposort_replace

2 changes: 1 addition & 1 deletion pymc/model/transform/conditioning.py
@@ -20,7 +20,7 @@
from pytensor.tensor import TensorVariable

from pymc import Model
-from pymc.logprob.transforms import Transform
+from pymc.logprob.abstract import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.model.fgraph import (
    ModelDeterministic,
15 changes: 15 additions & 0 deletions tests/distributions/test_multivariate.py
@@ -1770,6 +1770,21 @@ def test_batched_sigma(self):
            sigma=batch_test_sigma[None, :, None], n_zerosum_axes=2, support_shape=(3, 2)
        )

+    def test_transformed_logprob(self):
+        with pm.Model() as m:
+            x = pm.ZeroSumNormal("x", sigma=np.pi, shape=(5, 3), n_zerosum_axes=1)
+        pytensor.dprint(m.compile_logp().f)
+
+        [transformed_logp] = m.logp(sum=False)
+
+        unconstrained_value = m.rvs_to_values[x]
+        transform = m.rvs_to_transforms[x]
+        constrained_value = transform.backward(unconstrained_value)
+        reference_logp = pm.logp(x, constrained_value)
+
+        test_dict = {unconstrained_value: pm.draw(transform.forward(x))}
+        np.testing.assert_allclose(transformed_logp.eval(test_dict), reference_logp.eval(test_dict))


class TestMvStudentTCov(BaseTestDistributionRandom):
    def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
2 changes: 1 addition & 1 deletion tests/distributions/test_transform.py
@@ -24,8 +24,8 @@
import pymc as pm
import pymc.distributions.transforms as tr

+from pymc.logprob.abstract import Transform
from pymc.logprob.basic import transformed_conditional_logp
-from pymc.logprob.transforms import Transform
from pymc.pytensorf import floatX, jacobian
from pymc.testing import (
    Circ,
Expand Down