Allow freezing only subset of data and dims
* Also fix dim_lengths not being returned
ricardoV94 committed Apr 30, 2024
1 parent 04b6881 commit 0ad689c
Showing 3 changed files with 146 additions and 26 deletions.
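In short, `freeze_dims_and_data` gains optional `dims` and `data` arguments. A minimal sketch of the new call signature (the model, variable names, and coordinates here are illustrative, not from the commit):

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

# Hypothetical toy model with one named dimension and one data container.
with pm.Model(coords={"obs": range(100)}) as m:
    x = pm.Data("x", [0.0] * 100, dims="obs")
    y = pm.Normal("y", mu=x, dims="obs")

frozen_all = freeze_dims_and_data(m)                        # default: freeze everything
frozen_dim = freeze_dims_and_data(m, dims=["obs"], data=[])  # freeze only the dim
frozen_dat = freeze_dims_and_data(m, dims=[], data=["x"])    # freeze only the data
```

Passing `None` (the default) freezes all dims or data; an empty list freezes none.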
13 changes: 10 additions & 3 deletions pymc/model/core.py
@@ -1050,7 +1050,14 @@ def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = No
                     expected=new_length,
                 )
             self._coords[name] = tuple(coord_values)
-        self.dim_lengths[name].set_value(new_length)
+        dim_length = self.dim_lengths[name]
+        if not isinstance(dim_length, SharedVariable):
+            raise TypeError(
+                f"The dim_length of `{name}` must be a `SharedVariable` "
+                "(created through `coords` to allow updating). "
+                f"The current type is: {type(dim_length)}"
+            )
+        dim_length.set_value(new_length)
         return
 
     def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
@@ -1102,8 +1109,8 @@ def set_data(
         shared_object = self[name]
         if not isinstance(shared_object, SharedVariable):
             raise TypeError(
-                f"The variable `{name}` must be a `SharedVariable`"
-                " (created through `pm.Data()` or `pm.Data(mutable=True)`) to allow updating. "
+                f"The variable `{name}` must be a `SharedVariable` "
+                "(created through `pm.Data()` to allow updating.) "
                 f"The current type is: {type(shared_object)}"
             )
 
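With the new guard in `set_dim` (and the reworded one in `set_data`), updating a frozen dim or data container now fails with an informative `TypeError` rather than a less helpful `AttributeError`. A sketch, continuing the illustrative model above:

```python
# On the fully frozen copy, `x` and the "obs" dim length are now constants,
# so in-place updates are rejected with the messages added in this commit.
with frozen_all:
    try:
        frozen_all.set_data("x", [1.0] * 100)
    except TypeError as err:
        print(err)  # The variable `x` must be a `SharedVariable` ...

    try:
        frozen_all.set_dim("obs", new_length=50, coord_values=range(50))
    except TypeError as err:
        print(err)  # The dim_length of `obs` must be a `SharedVariable` ...
```

The original model `m` keeps its `SharedVariable`s and can still be updated as before.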
66 changes: 48 additions & 18 deletions pymc/model/transform/optimization.py
@@ -11,16 +11,27 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+from collections.abc import Sequence
+
 from pytensor import clone_replace
 from pytensor.compile import SharedVariable
 from pytensor.graph import FunctionGraph
 from pytensor.tensor import constant
+from pytensor.tensor.sharedvar import TensorSharedVariable
+from pytensor.tensor.variable import TensorConstant
 
 from pymc import Model
 from pymc.model.fgraph import ModelFreeRV, fgraph_from_model, model_from_fgraph
 
 
-def freeze_dims_and_data(model: Model) -> Model:
+def _constant_from_shared(shared: SharedVariable) -> TensorConstant:
+    assert isinstance(shared, TensorSharedVariable)
+    return constant(shared.get_value(), name=shared.name, dtype=shared.type.dtype)
+
+
+def freeze_dims_and_data(
+    model: Model, dims: Sequence[str] | None = None, data: Sequence[str] | None = None
+) -> Model:
     """Recreate a Model with fixed RV dimensions and Data values.
 
     The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
@@ -30,41 +41,60 @@ def freeze_dims_and_data(model: Model) -> Model:
     This transformation may allow more performant sampling, or compiling model functions to backends that
     are more restrictive about dynamic shapes such as JAX.
 
+    Parameters
+    ----------
+    model : Model
+        The model in which to freeze dims and data.
+    dims : Sequence of str, optional
+        The dimensions to freeze.
+        If None, all dimensions are frozen. Pass an empty list to avoid freezing any dimension.
+    data : Sequence of str, optional
+        The data to freeze.
+        If None, all data are frozen. Pass an empty list to avoid freezing any data.
+
+    Returns
+    -------
+    Model
+        A new model with the specified dimensions and data frozen.
     """
     fg, memo = fgraph_from_model(model)
 
+    if dims is None:
+        dims = tuple(model.dim_lengths.keys())
+    if data is None:
+        data = tuple(model.named_vars.keys())
+
     # Replace mutable dim lengths and data by constants
-    frozen_vars = {
-        memo[dim_length]: constant(
-            dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
-        )
-        for dim_length in model.dim_lengths.values()
+    frozen_replacements = {
+        memo[dim_length]: _constant_from_shared(dim_length)
+        for dim_length in (model.dim_lengths[dim_name] for dim_name in dims)
         if isinstance(dim_length, SharedVariable)
     }
-    frozen_vars |= {
-        memo[data_var].owner.inputs[0]: constant(
-            data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype
-        )
-        for data_var in model.named_vars.values()
-        if isinstance(data_var, SharedVariable)
+    frozen_replacements |= {
+        memo[datum].owner.inputs[0]: _constant_from_shared(datum)
+        for datum in (model.named_vars[datum_name] for datum_name in data)
+        if isinstance(datum, SharedVariable)
     }
 
-    old_outs, coords = fg.outputs, fg._coords  # type: ignore
+    old_outs, old_coords, old_dim_lengths = fg.outputs, fg._coords, fg._dim_lengths  # type: ignore
     # Rebuild strict will force the recreation of RV nodes with updated static types
-    new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False)  # type: ignore
+    new_outs = clone_replace(old_outs, replace=frozen_replacements, rebuild_strict=False)  # type: ignore
     for old_out, new_out in zip(old_outs, new_outs):
         new_out.name = old_out.name
     fg = FunctionGraph(outputs=new_outs, clone=False)
-    fg._coords = coords  # type: ignore
+    fg._coords = old_coords  # type: ignore
+    fg._dim_lengths = {  # type: ignore
+        dim: frozen_replacements.get(dim_length, dim_length)
+        for dim, dim_length in old_dim_lengths.items()
+    }
 
     # Recreate value variables from new RVs to propagate static types to logp graphs
     replacements = {}
     for node in fg.apply_nodes:
         if not isinstance(node.op, ModelFreeRV):
            continue
-        rv, old_value, *dims = node.inputs
-        if dims is None:
-            continue
+        rv, old_value, *_ = node.inputs
         transform = node.op.transform
         if transform is None:
             new_value = rv.type()
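The docstring names JAX-style backends as the main beneficiary: freezing just the dims needed for static shapes can unlock compilation while leaving the rest shared. A hedged sketch (assumes the optional numpyro backend is installed; the model is the illustrative one from above):

```python
# Freeze only the dim lengths so RV shapes become static for JAX,
# but leave the data containers shared.
frozen = freeze_dims_and_data(m, dims=["obs"], data=[])

with frozen:
    idata = pm.sample(nuts_sampler="numpyro")  # JAX-based NUTS prefers static shapes
```

Data values can still be swapped on `frozen` afterwards, as long as the new values keep the now-fixed shape.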
93 changes: 88 additions & 5 deletions tests/model/transform/test_optimization.py
@@ -11,17 +11,22 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+import numpy as np
+import pytest
+
+from pytensor.compile import SharedVariable
 from pytensor.graph import Constant
 
+from pymc import Deterministic
 from pymc.data import Data
 from pymc.distributions import HalfNormal, Normal
 from pymc.model import Model
 from pymc.model.transform.optimization import freeze_dims_and_data
 
 
-def test_freeze_existing_rv_dims_and_data():
+def test_freeze_dims_and_data():
     with Model(coords={"test_dim": range(5)}) as m:
-        std = Data("std", [1])
+        std = Data("test_data", [1])
         x = HalfNormal("x", std, dims=("test_dim",))
         y = Normal("y", shape=x.shape[0] + 1)
@@ -34,18 +39,96 @@ def test_freeze_existing_rv_dims_and_data():
     assert y_logp.type.shape == (None,)
 
     frozen_m = freeze_dims_and_data(m)
-    std, x, y = frozen_m["std"], frozen_m["x"], frozen_m["y"]
+    data, x, y = frozen_m["test_data"], frozen_m["x"], frozen_m["y"]
     x_logp, y_logp = frozen_m.logp(sum=False)
-    assert isinstance(std, Constant)
+    assert isinstance(data, Constant)
     assert x.type.shape == (5,)
     assert y.type.shape == (6,)
     assert x_logp.type.shape == (5,)
     assert y_logp.type.shape == (6,)
 
+    # Test trying to update a frozen data or dim raises an informative error
+    with frozen_m:
+        with pytest.raises(TypeError, match="The variable `test_data` must be a `SharedVariable`"):
+            frozen_m.set_data("test_data", values=[2])
+        with pytest.raises(
+            TypeError, match="The dim_length of `test_dim` must be a `SharedVariable`"
+        ):
+            frozen_m.set_dim("test_dim", new_length=6, coord_values=range(6))
+
+    # Test we can still update original model
+    with m:
+        m.set_data("test_data", values=[2])
+        m.set_dim("test_dim", new_length=6, coord_values=range(6))
+    assert m["test_data"].get_value() == [2]
+    assert m.dim_lengths["test_dim"].get_value() == 6
 
-def test_freeze_rv_dims_nothing_to_change():
+
+def test_freeze_dims_nothing_to_change():
     with Model(coords={"test_dim": range(5)}) as m:
         x = HalfNormal("x", shape=(5,))
         y = Normal("y", shape=x.shape[0] + 1)
 
     assert m.point_logps() == freeze_dims_and_data(m).point_logps()
+
+
+def test_freeze_dims_and_data_subset():
+    with Model(coords={"dim1": range(3), "dim2": range(5)}) as m:
+        data1 = Data("data1", [1, 2, 3], dims="dim1")
+        data2 = Data("data2", [1, 2, 3, 4, 5], dims="dim2")
+        var1 = Normal("var1", dims="dim1")
+        var2 = Normal("var2", dims="dim2")
+        x = data1 * var1
+        y = data2 * var2
+        det = Deterministic("det", x[:, None] + y[None, :])
+
+    assert det.type.shape == (None, None)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1"], data=[])
+    assert new_m["det"].type.shape == (3, None)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim2"], data=[])
+    assert new_m["det"].type.shape == (None, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1", "dim2"], data=[])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data1"])
+    assert new_m["det"].type.shape == (3, None)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3])
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data2"])
+    assert new_m["det"].type.shape == (None, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
+
+    new_m = freeze_dims_and_data(m, dims=[], data=["data1", "data2"])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3])
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
+
+    new_m = freeze_dims_and_data(m, dims=["dim1"], data=["data2"])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])
