Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when freezing_rv_and_dims after a model transformation #7296

Merged
merged 2 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,18 @@ def fgraph_from_model(
return fgraph, memo


def model_from_fgraph(fgraph: FunctionGraph) -> Model:
def model_from_fgraph(fgraph: FunctionGraph, mutate_fgraph: bool = False) -> Model:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every time you call the function you use mutate_fgraph=True. Should that be the default instead? Or are the instances in this PR special cases (and there are many others)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to ask for consent

"""Convert FunctionGraph to PyMC model.

This requires nodes to be properly tagged with `ModelVar` dummy Ops.
Parameters
----------
fgraph: FunctionGraph
fgraph representation of a PyMC model, with dummy `ModelVar` Ops.
See `fgraph_from_model` for more details.

See: fgraph_from_model
mutate_fgraph: bool, default False
Whether the function is allowed to modify the fgraph (and its variables) in place.
This is useful if these are not needed anymore after the model is created.
"""

def first_non_model_var(var):
Expand All @@ -296,11 +302,21 @@ def first_non_model_var(var):
model = Model()
if model.parent is not None:
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
model._coords = getattr(fgraph, "_coords", {})
model._dim_lengths = getattr(fgraph, "_dim_lengths", {})

_coords = getattr(fgraph, "_coords", {})
_dim_lengths = getattr(fgraph, "_dim_lengths", {})

if not mutate_fgraph:
fgraph, memo = fgraph.clone_get_equiv(check_integrity=False, attach_feature=False)
# Shared dim lengths are not extracted from the fgraph representation,
# so we need to update after we clone the fgraph
# TODO: Consider representing/extracting them from the fgraph!
_dim_lengths = {k: memo.get(v, v) for k, v in _dim_lengths.items()}

model._coords = _coords
model._dim_lengths = _dim_lengths

# Replace dummy `ModelVar` Ops by the underlying variables,
fgraph = fgraph.clone()
model_dummy_vars = [
model_node.outputs[0]
for model_node in fgraph.toposort()
Expand Down Expand Up @@ -376,7 +392,7 @@ def clone_model(model: Model) -> Model:
z = pm.Deterministic("z", clone_x + 1)

"""
return model_from_fgraph(fgraph_from_model(model)[0])
return model_from_fgraph(fgraph_from_model(model)[0], mutate_fgraph=True)


def extract_dims(var) -> tuple:
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
}
for node_to_remove in nodes_to_remove:
fgraph.remove_node(node_to_remove)
return model_from_fgraph(fgraph)
return model_from_fgraph(fgraph, mutate_fgraph=True)


def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> list[Variable]:
Expand Down
6 changes: 3 additions & 3 deletions pymc/model/transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def observe(

toposort_replace(fgraph, tuple(replacements.items()))

return model_from_fgraph(fgraph)
return model_from_fgraph(fgraph, mutate_fgraph=True)


def do(
Expand Down Expand Up @@ -215,7 +215,7 @@ def do(
# Replace variables by interventions
toposort_replace(fgraph, tuple(replacements.items()))

model = model_from_fgraph(fgraph)
model = model_from_fgraph(fgraph, mutate_fgraph=True)
if prune_vars:
return prune_vars_detached_from_observed(model)
return model
Expand Down Expand Up @@ -302,7 +302,7 @@ def change_value_transforms(
replacements[dummy_rv] = new_dummy_rv

toposort_replace(fgraph, tuple(replacements.items()))
return model_from_fgraph(fgraph)
return model_from_fgraph(fgraph, mutate_fgraph=True)


def remove_value_transforms(
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def freeze_dims_and_data(
replacements[old_value] = new_value
fg.replace_all(tuple(replacements.items()), import_missing=True)

return model_from_fgraph(fg)
return model_from_fgraph(fg, mutate_fgraph=True)


__all__ = ("freeze_dims_and_data",)
7 changes: 6 additions & 1 deletion tests/model/test_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from pytensor import config, shared
from pytensor.graph import Constant, FunctionGraph, node_rewriter
from pytensor.graph import Constant, FunctionGraph, graph_inputs, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.exceptions import NotScalarConstantError

Expand Down Expand Up @@ -164,6 +164,11 @@ def test_data(inline_views):
np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0])
np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6)

# Check model dim_lengths contains the exact variables used in the graph of RVs
m_new_size_param = m_new["obs"].owner.inputs[1]
[m_new_dim_len] = graph_inputs([m_new_size_param])
assert m_new.dim_lengths["test_dim"] is m_new_dim_len


@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph
def test_shared_variable():
Expand Down
14 changes: 13 additions & 1 deletion tests/model/transform/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pytensor.compile import SharedVariable
from pytensor.graph import Constant

from pymc import Deterministic
from pymc import Deterministic, do
from pymc.data import Data
from pymc.distributions import HalfNormal, Normal
from pymc.model import Model
Expand Down Expand Up @@ -132,3 +132,15 @@ def test_freeze_dims_and_data_subset():
assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
assert isinstance(new_m["data1"], SharedVariable)
assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])


def test_freeze_dim_after_do_intervention():
    """Check that `freeze_dims_and_data` works on a model returned by `do`.

    A model with a single dim of length 5 is built, its `mu` data is replaced
    through a `do` intervention, and the result is then frozen. The static
    shape of `x` must be unknown before freezing and `(5,)` afterwards.
    """
    coords = {"test_dim": range(5)}
    with Model(coords=coords) as base_model:
        mu_data = Data("mu", [0, 1, 2, 3, 4], dims="test_dim")
        Normal("x", mu=mu_data, dims="test_dim")

    intervened_model = do(base_model, {mu_data: mu_data * 100})
    # Before freezing, the static shape of `x` is not known.
    assert intervened_model["x"].type.shape == (None,)

    frozen_model = freeze_dims_and_data(intervened_model)
    # After freezing, the dim length is baked into the static shape.
    assert frozen_model["x"].type.shape == (5,)
Loading