-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement specialized transformed logp dispatch #7188
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,10 +38,15 @@ | |
|
||
from collections.abc import Sequence | ||
from functools import singledispatch | ||
from typing import Union | ||
|
||
import multipledispatch | ||
import pytensor.tensor as pt | ||
|
||
from pytensor.gradient import jacobian | ||
from pytensor.graph.op import Op | ||
from pytensor.graph.utils import MetaType | ||
from pytensor.tensor import TensorVariable | ||
from pytensor.tensor import TensorVariable, Variable | ||
from pytensor.tensor.elemwise import Elemwise | ||
from pytensor.tensor.random.op import RandomVariable | ||
|
||
|
@@ -153,3 +158,52 @@ def __init__(self, scalar_op, *args, **kwargs): | |
|
||
|
||
MeasurableVariable.register(MeasurableElemwise) | ||
|
||
|
||
class Transform(abc.ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moved to abstract, without any changes. |
||
ndim_supp = None | ||
|
||
@abc.abstractmethod | ||
def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable: | ||
"""Apply the transformation.""" | ||
|
||
@abc.abstractmethod | ||
def backward( | ||
self, value: TensorVariable, *inputs: Variable | ||
) -> Union[TensorVariable, tuple[TensorVariable, ...]]: | ||
"""Invert the transformation. Multiple values may be returned when the | ||
transformation is not 1-to-1""" | ||
|
||
def log_jac_det(self, value: TensorVariable, *inputs) -> TensorVariable: | ||
"""Construct the log of the absolute value of the Jacobian determinant.""" | ||
if self.ndim_supp not in (0, 1): | ||
raise NotImplementedError( | ||
f"RVTransform default log_jac_det only implemented for ndim_supp in (0, 1), got {self.ndim_supp=}" | ||
) | ||
if self.ndim_supp == 0: | ||
jac = pt.reshape(pt.grad(pt.sum(self.backward(value, *inputs)), [value]), value.shape) | ||
return pt.log(pt.abs(jac)) | ||
else: | ||
phi_inv = self.backward(value, *inputs) | ||
return pt.log(pt.abs(pt.nlinalg.det(pt.atleast_2d(jacobian(phi_inv, [value])[0])))) | ||
|
||
def __str__(self): | ||
return f"{self.__class__.__name__}" | ||
|
||
|
||
@multipledispatch.dispatch(Op, Transform) | ||
def _transformed_logprob( | ||
op: Op, | ||
transform: Transform, | ||
unconstrained_value: TensorVariable, | ||
rv_inputs: Sequence[TensorVariable], | ||
): | ||
"""Create a graph for the log-density/mass of a transformed ``RandomVariable``. | ||
|
||
This function dispatches on the type of ``op``, which should be a subclass | ||
of ``RandomVariable`` and ``transform``, which should be a subclass of ``Transform``. | ||
|
||
""" | ||
raise NotImplementedError( | ||
f"Transformed logprob method not implemented for {op} with transform {transform}" | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,12 +23,12 @@ | |
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.replace import clone_replace | ||
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter | ||
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB | ||
from pytensor.scan.op import Scan | ||
from pytensor.tensor.variable import TensorVariable | ||
|
||
from pymc.logprob.abstract import MeasurableVariable, _logprob | ||
from pymc.logprob.abstract import MeasurableVariable, Transform, _logprob, _transformed_logprob | ||
from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db | ||
from pymc.logprob.transforms import Transform | ||
|
||
|
||
class TransformedValue(Op): | ||
|
@@ -97,7 +97,26 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) | |
This is introduced by the `TransformValuesRewrite` | ||
""" | ||
rv_op = rv_outs[0].owner.op | ||
transforms = op.transforms | ||
rv_inputs = rv_outs[0].owner.inputs | ||
|
||
if use_jacobian and len(values) == 1 and len(transforms) == 1: | ||
# Check if there's a specialized transform logp implemented | ||
[value] = values | ||
assert isinstance(value.owner.op, TransformedValue) | ||
unconstrained_value = value.owner.inputs[1] | ||
[transform] = transforms | ||
try: | ||
return _transformed_logprob( | ||
rv_op, | ||
transform, | ||
unconstrained_value=unconstrained_value, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be nice if we could also pass in the transformed value. That way we can avoid computing it twice in a graph if the logp can use that value too. I guess pytensor might get rid of that duplication anyway, but I don't know how reliable that is if the transformation is doing something more complicated. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can pass it too. You raise a point. Maybe transforms should be encapsulated in an OpFromGraph so that we can easily reverse symbolically and not worry whether they will be simplified in their raw form or not. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually we have the constrained value as well, its |
||
rv_inputs=rv_inputs, | ||
**kwargs, | ||
) | ||
except NotImplementedError: | ||
pass | ||
|
||
logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs) | ||
|
||
if not isinstance(logprobs, Sequence): | ||
|
@@ -112,8 +131,8 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs) | |
continue | ||
|
||
assert isinstance(value.owner.op, TransformedValue) | ||
original_forward_value = value.owner.inputs[1] | ||
log_jac_det = transform.log_jac_det(original_forward_value, *rv_inputs).copy() | ||
unconstrained_value = value.owner.inputs[1] | ||
log_jac_det = transform.log_jac_det(unconstrained_value, *rv_inputs).copy() | ||
# The jacobian determinant has less dims than the logp | ||
# when a multivariate transform (like Simplex or Ordered) is applied to univariate distributions. | ||
# In this case we have to reduce the last logp dimensions, as they are no longer independent | ||
|
@@ -299,6 +318,17 @@ def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> Optional[list[A | |
return transformed_rv_node.outputs | ||
|
||
|
||
transform_values_rewrites_db = SequenceDB() | ||
transform_values_rewrites_db.name = "transform_values_rewrites_db" | ||
|
||
transform_values_rewrites_db.register( | ||
"transform_values", in2out(transform_values, ignore_newtrees=True), "basic" | ||
) | ||
transform_values_rewrites_db.register( | ||
"transform_scan_values", in2out(transform_scan_values, ignore_newtrees=True), "basic" | ||
) | ||
|
||
|
||
class TransformValuesMapping(Feature): | ||
r"""A `Feature` that maintains a map between value variables and their transforms.""" | ||
|
||
|
@@ -315,9 +345,6 @@ def on_attach(self, fgraph): | |
class TransformValuesRewrite(GraphRewriter): | ||
r"""Transforms value variables according to a map.""" | ||
|
||
transform_rewrite = in2out(transform_values, ignore_newtrees=True) | ||
scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True) | ||
|
||
def __init__( | ||
self, | ||
values_to_transforms: dict[TensorVariable, Union[Transform, None]], | ||
|
@@ -340,8 +367,8 @@ def add_requirements(self, fgraph): | |
fgraph.attach_feature(values_transforms_feature) | ||
|
||
def apply(self, fgraph: FunctionGraph): | ||
self.transform_rewrite.rewrite(fgraph) | ||
self.scan_transform_rewrite.rewrite(fgraph) | ||
query = RewriteDatabaseQuery(include=["basic"]) | ||
transform_values_rewrites_db.query(query).rewrite(fgraph) | ||
|
||
|
||
@node_rewriter([TransformedValue]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
I don't think this check is necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
It was a sanity check. A user could have defined an invalid transform manually. I don't really care either way.