Adding meta-information for MeasurableOps #7076

Open · wants to merge 19 commits into base: main
57 changes: 55 additions & 2 deletions pymc/logprob/abstract.py
@@ -37,7 +37,9 @@
import abc

from collections.abc import Sequence
from enum import Enum, auto
from functools import singledispatch
from typing import Union

from pytensor.graph.op import Op
from pytensor.graph.utils import MetaType
@@ -131,14 +133,33 @@ def _icdf_helper(rv, value, **kwargs):
return rv_icdf


class MeasureType(Enum):
Discrete = auto()
Continuous = auto()
Mixed = auto()


class MeasurableVariable(abc.ABC):
"""A variable that can be assigned a measure/log-probability"""

def __init__(
self,
*args,
ndim_supp: Union[int, tuple],
supp_axes: tuple,
measure_type: Union[MeasureType, tuple],
**kwargs,
):
self.ndim_supp = ndim_supp
self.supp_axes = supp_axes
self.measure_type = measure_type
super().__init__(*args, **kwargs)


MeasurableVariable.register(RandomVariable)


class MeasurableElemwise(Elemwise):
class MeasurableElemwise(MeasurableVariable, Elemwise):
"""Base class for Measurable Elemwise variables"""

valid_scalar_types: tuple[MetaType, ...] = ()
@@ -152,4 +173,36 @@ def __init__(self, scalar_op, *args, **kwargs):
super().__init__(scalar_op, *args, **kwargs)


MeasurableVariable.register(MeasurableElemwise)
def get_measure_type_info(
base_var,
):
from pymc.logprob.utils import DiracDelta

if not isinstance(base_var, MeasurableVariable):
base_op = base_var.owner.op
index = base_var.owner.outputs.index(base_var)
else:
base_op = base_var
if not isinstance(base_op, MeasurableVariable):
raise TypeError("base_op must be a RandomVariable or MeasurableVariable")

if isinstance(base_op, DiracDelta):
ndim_supp = 0
supp_axes = ()
measure_type = MeasureType.Discrete
return ndim_supp, supp_axes, measure_type

if isinstance(base_op, RandomVariable):
ndim_supp = base_op.ndim_supp
supp_axes = tuple(range(-ndim_supp, 0))
measure_type = (
MeasureType.Continuous if base_op.dtype.startswith("float") else MeasureType.Discrete
)
return base_op.ndim_supp, supp_axes, measure_type
else:
# We'll need this for operators like scan and IfElse
if isinstance(base_op.ndim_supp, tuple):
if len(base_var.owner.outputs) != len(base_op.ndim_supp):
raise NotImplementedError("length of outputs and meta-properties is different")
Comment on lines +205 to +206
This is too restrictive for Scan, which can have recurrent outputs that are not measurable variables. Let's just remove it for now?

return base_op.ndim_supp[index], base_op.supp_axes, base_op.measure_type
return base_op.ndim_supp, base_op.supp_axes, base_op.measure_type
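
For context (not part of the diff): a minimal sketch of how the new `get_measure_type_info` helper is expected to behave for plain `RandomVariable`s, assuming this branch is installed. The `pm.Normal.dist()` and `pm.Dirichlet.dist()` inputs are only illustrative.

```python
import pymc as pm

from pymc.logprob.abstract import MeasureType, get_measure_type_info

# Scalar continuous RV: 0-dim support, no support axes, float dtype -> Continuous
x = pm.Normal.dist(mu=0.0, sigma=1.0)
assert get_measure_type_info(x) == (0, (), MeasureType.Continuous)

# Multivariate RV: the support spans the last axis, counted from the end
w = pm.Dirichlet.dist(a=[1.0, 1.0, 1.0])
assert get_measure_type_info(w) == (1, (-1,), MeasureType.Continuous)
```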
19 changes: 17 additions & 2 deletions pymc/logprob/binary.py
@@ -25,9 +25,11 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasureType,
_logcdf_helper,
_logprob,
_logprob_helper,
get_measure_type_info,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability
@@ -81,7 +83,14 @@ def find_measurable_comparisons(
elif isinstance(node_scalar_op, LE):
node_scalar_op = GE()

compared_op = MeasurableComparison(node_scalar_op)
ndim_supp, supp_axes, _ = get_measure_type_info(measurable_var)

compared_op = MeasurableComparison(
scalar_op=node_scalar_op,
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=MeasureType.Discrete,
)
compared_rv = compared_op.make_node(measurable_var, const).default_output()
return [compared_rv]

@@ -148,7 +157,13 @@ def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[
return None

node_scalar_op = node.op.scalar_op
bitwise_op = MeasurableBitwise(node_scalar_op)
ndim_supp, supp_axis, measure_type = get_measure_type_info(base_var)
bitwise_op = MeasurableBitwise(
scalar_op=node_scalar_op,
ndim_supp=ndim_supp,
supp_axes=supp_axis,
measure_type=MeasureType.Discrete,
)
bitwise_rv = bitwise_op.make_node(base_var).default_output()
return [bitwise_rv]
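
An illustration (not part of the diff) of why the ops in this file are always tagged `MeasureType.Discrete`, whatever the base measure is: a comparison or bitwise result only carries mass on {0, 1}. A sketch assuming `pm.logp` derives the comparison logprob as in current PyMC:

```python
import numpy as np
import pymc as pm

x = pm.Normal.dist(0.0, 1.0)
indicator = x > 0  # continuous base variable, but the comparison itself is discrete

# P(x > 0) = 0.5 for a standard normal, so logp(indicator == 1) should be log(0.5)
np.testing.assert_allclose(np.exp(pm.logp(indicator, 1).eval()), 0.5, rtol=1e-6)
```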

26 changes: 21 additions & 5 deletions pymc/logprob/censoring.py
@@ -48,7 +48,13 @@
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import (
MeasurableElemwise,
MeasureType,
_logcdf,
_logprob,
get_measure_type_info,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue

@@ -59,9 +65,6 @@ class MeasurableClip(MeasurableElemwise):
valid_scalar_types = (Clip,)


measurable_clip = MeasurableClip(scalar_clip)


@node_rewriter(tracks=[clip])
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
@@ -81,6 +84,15 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[Te
lower_bound = lower_bound if (lower_bound is not base_var) else pt.constant(-np.inf)
upper_bound = upper_bound if (upper_bound is not base_var) else pt.constant(np.inf)

ndim_supp, supp_axes, measure_type = get_measure_type_info(base_var)

if measure_type == MeasureType.Continuous:
measure_type = MeasureType.Mixed

measurable_clip = MeasurableClip(
scalar_clip, ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)

clipped_rv = measurable_clip.make_node(base_var, lower_bound, upper_bound).outputs[0]
return [clipped_rv]

@@ -167,7 +179,11 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[lis
return None

[base_var] = node.inputs
rounded_op = MeasurableRound(node.op.scalar_op)
ndim_supp, supp_axis, _ = get_measure_type_info(base_var)
measure_type = MeasureType.Discrete
rounded_op = MeasurableRound(
node.op.scalar_op, ndim_supp=ndim_supp, supp_axes=supp_axis, measure_type=measure_type
)
rounded_rv = rounded_op.make_node(base_var).default_output()
rounded_rv.name = node.outputs[0].name
return [rounded_rv]
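
To motivate promoting `MeasureType.Continuous` to `MeasureType.Mixed` for clips (not part of the diff): clipping moves the tail probability onto the bounds, leaving point masses at the bounds plus a density in between. A sketch assuming the existing censoring logprob in PyMC:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

from scipy import stats

x = pm.Normal.dist(0.0, 1.0)
clipped = pt.clip(x, -1.0, 1.0)

# At the lower bound the censored variable carries a point mass of P(x <= -1),
# which is exactly why a clipped continuous base is a mixed measure.
logp_at_bound = pm.logp(clipped, -1.0).eval()
np.testing.assert_allclose(np.exp(logp_at_bound), stats.norm.cdf(-1.0), rtol=1e-6)
```

Rounding, by contrast, collapses all mass onto a countable set, so `MeasurableRound` is tagged `MeasureType.Discrete` unconditionally.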
32 changes: 21 additions & 11 deletions pymc/logprob/checks.py
@@ -43,18 +43,20 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.shape import SpecifyShape

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import (
MeasurableVariable,
_logprob,
_logprob_helper,
get_measure_type_info,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableSpecifyShape(SpecifyShape):
class MeasurableSpecifyShape(MeasurableVariable, SpecifyShape):
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""


MeasurableVariable.register(MeasurableSpecifyShape)


@_logprob.register(MeasurableSpecifyShape)
def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
(value,) = values
@@ -86,7 +88,11 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable
):
return None # pragma: no cover

new_op = MeasurableSpecifyShape()
ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv)

new_op = MeasurableSpecifyShape(
ndim_supp=ndim_supp, supp_axes=supp_axes, measure_type=measure_type
)
new_rv = new_op.make_node(base_rv, *shape).default_output()

return [new_rv]
@@ -100,13 +106,10 @@ def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable
)


class MeasurableCheckAndRaise(CheckAndRaise):
class MeasurableCheckAndRaise(MeasurableVariable, CheckAndRaise):
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""


MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
(value,) = values
@@ -133,7 +136,14 @@ def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariabl
return None

op = node.op
new_op = MeasurableCheckAndRaise(exc_type=op.exc_type, msg=op.msg)
ndim_supp, supp_axis, d_type = get_measure_type_info(base_rv)
new_op = MeasurableCheckAndRaise(
exc_type=op.exc_type,
msg=op.msg,
ndim_supp=ndim_supp,
supp_axes=supp_axis,
measure_type=d_type,
)
new_rv = new_op.make_node(base_rv, *conds).default_output()

return [new_rv]
21 changes: 15 additions & 6 deletions pymc/logprob/cumsum.py
@@ -42,17 +42,19 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import CumOp

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import (
MeasurableVariable,
_logprob,
_logprob_helper,
get_measure_type_info,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db


class MeasurableCumsum(CumOp):
class MeasurableCumsum(MeasurableVariable, CumOp):
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""


MeasurableVariable.register(MeasurableCumsum)


@_logprob.register(MeasurableCumsum)
def logprob_cumsum(op, values, base_rv, **kwargs):
"""Compute the log-likelihood graph for a `Cumsum`."""
@@ -101,7 +103,14 @@ def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]:
if not rv_map_feature.request_measurable(node.inputs):
return None

new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add")
ndim_supp, supp_axes, measure_type = get_measure_type_info(base_rv)
new_op = MeasurableCumsum(
axis=node.op.axis or 0,
mode="add",
ndim_supp=ndim_supp,
supp_axes=supp_axes,
measure_type=measure_type,
)
new_rv = new_op.make_node(base_rv).default_output()

return [new_rv]
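
Cumsum only reparametrizes its input (it neither discretizes nor censors anything), so the new op forwards the base meta-information unchanged. A sketch (not part of the diff) of the logprob identity this rests on, assuming current `pm.logp` behaviour:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

x = pm.Normal.dist(0.0, 1.0, size=3)
cs = pt.cumsum(x)

value = np.array([0.5, 1.0, 1.5])
# The cumsum logp is the base logp evaluated on the differenced value,
# so the joint (summed) log-probabilities agree.
np.testing.assert_allclose(
    pm.logp(cs, value).eval().sum(),
    pm.logp(x, np.diff(value, prepend=0.0)).eval().sum(),
)
```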