Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not consider dims without coords volatile if length has not changed #7381

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
)
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.variable import TensorConstant
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
Expand Down Expand Up @@ -73,6 +74,28 @@
_log = logging.getLogger(__name__)


def get_constant_coords(trace_coords: dict[str, np.ndarray], model: Model) -> set:
"""Get the set of coords that have remained constant between the trace and model"""
constant_coords = set()
for dim, coord in trace_coords.items():
current_coord = model.coords.get(dim, None)
current_length = model.dim_lengths.get(dim, None)
if isinstance(current_length, TensorSharedVariable):
current_length = current_length.get_value()
elif isinstance(current_length, TensorConstant):
current_length = current_length.data

Check warning on line 86 in pymc/sampling/forward.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/forward.py#L86

Added line #L86 was not covered by tests
if (
current_coord is not None
and len(coord) == len(current_coord)
and np.all(coord == current_coord)
) or (
# Coord was defined without values (only length)
current_coord is None and len(coord) == current_length
):
constant_coords.add(dim)
return constant_coords


def get_vars_in_point_list(trace, model):
"""Get the list of Variable instances in the model that have values stored in the trace."""
if not isinstance(trace, MultiTrace):
Expand Down Expand Up @@ -789,15 +812,7 @@
stacklevel=2,
)

constant_coords = set()
for dim, coord in trace_coords.items():
current_coord = model.coords.get(dim, None)
if (
current_coord is not None
and len(coord) == len(current_coord)
and np.all(coord == current_coord)
):
constant_coords.add(dim)
constant_coords = get_constant_coords(trace_coords, model)

if var_names is not None:
vars_ = [model[x] for x in var_names]
Expand Down
54 changes: 54 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pymc.pytensorf import compile_pymc
from pymc.sampling.forward import (
compile_forward_sampling_function,
get_constant_coords,
get_vars_in_point_list,
observed_dependent_deterministics,
)
Expand Down Expand Up @@ -428,6 +429,45 @@ def test_mutable_coords_volatile(self):
"offsets",
}

def test_length_coords_volatile(self):
with pm.Model() as model:
model.add_coord("trial", length=3)
x = pm.Normal("x", dims="trial")
y = pm.Deterministic("y", x.mean())

# Same coord length -- `x` is not volatile
trace_same_len = az_from_dict(
posterior={"x": [[[np.pi] * 3]]},
coords={"trial": range(3)},
dims={"x": ["trial"]},
)
with model:
pp_same_len = pm.sample_posterior_predictive(
trace_same_len, var_names=["y"]
).posterior_predictive
assert pp_same_len["y"] == np.pi

# Coord length changed -- `x` is volatile
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to change the dim length in the pymc model and pass exactly the same idata to sample_posterior_predictive as was done in the first call.

Copy link
Member

@ricardoV94 ricardoV94 Jul 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way to update an existing dim is with Model.set_dim()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, just got to this now. Implemented in 49bbf8c -- hope I understood you correctly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup you got it!

trace_diff_len = az_from_dict(
posterior={"x": [[[np.pi] * 2]]},
coords={"trial": range(2)},
dims={"x": ["trial"]},
)
with model:
pp_diff_len = pm.sample_posterior_predictive(
trace_diff_len, var_names=["y"]
).posterior_predictive
assert pp_diff_len["y"] != np.pi

# Changing the dim length on the model itself
# -- `x` is volatile because trace has same len as original model
model.set_dim("trial", new_length=7)
with model:
pp_diff_len_model_set = pm.sample_posterior_predictive(
trace_same_len, var_names=["y"]
).posterior_predictive
assert pp_diff_len_model_set["y"] != np.pi


class TestSamplePPC:
def test_normal_scalar(self):
Expand Down Expand Up @@ -1670,6 +1710,20 @@ def test_Triangular(
assert prior["target"].shape == (prior_samples, *shape)


def test_get_constant_coords():
with pm.Model() as model:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a more integration-like test like?

Spit-balling, perhaps something like:

with pm.Model() as m:
  m.add_coord("trial", length=3)
  x = pm.Normal("x", shape="trial")
  y = pm.Deterministic("y", x.mean())

# pass dims as well  
idata = az.from_dict("x": [[np.pi, np.pi, np.pi]])

with m:
  pp1 = pm.sample_posterior_predictive(idata, var_names=["y"]).posterior_predictive

assert pp1["y"] == np.pi

with m:
  # change coord length 
  pp2 = pm.sample_posterior_predictive(idata, var_names=["y"]).posterior_predictive

assert pp2["y"] != np.pi

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this look ok? d40fffe

model.add_coord("length_coord", length=1)
model.add_coord("value_coord", values=(3,))

trace_coords_same = {"length_coord": np.array([0]), "value_coord": np.array([3])}
constant_coords_same = get_constant_coords(trace_coords_same, model)
assert constant_coords_same == {"length_coord", "value_coord"}

trace_coords_diff = {"length_coord": np.array([0, 1]), "value_coord": np.array([4])}
constant_coords_diff = get_constant_coords(trace_coords_diff, model)
assert constant_coords_diff == set()


def test_get_vars_in_point_list():
with pm.Model() as modelA:
pm.Normal("a", 0, 1)
Expand Down
Loading