Skip to content

Commit

Permalink
Add test for freeze_dims_and_data in JAX backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed May 30, 2024
1 parent 40fb76c commit c495d8a
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from pymc import ImputationWarning
from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
_get_log_likelihood,
Expand Down Expand Up @@ -514,6 +515,24 @@ def test_convergence_warnings(caplog, nuts_sampler):


def test_dirichlet_multinomial():
"""Test we can draw from a DM in the JAX backend if the shape is constant."""
dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
dm_draws = pm.draw(dm, mode="JAX")
np.testing.assert_equal(dm_draws, np.eye(3) * 5)


def test_dirichlet_multinomial_dims():
"""Test we can draw from a DM with a shape defined by dims in the JAX backend,
after freezing those dims.
"""
with pm.Model(coords={"trial": range(3), "item": range(3)}) as m:
dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item"))

# JAX does not allow us to JIT a function with dynamic shape
with pytest.raises(TypeError):
pm.draw(dm, mode="JAX")

# Should be fine after freezing the dims that specify the shape
frozen_dm = freeze_dims_and_data(m)["dm"]
dm_draws = pm.draw(frozen_dm, mode="JAX")
np.testing.assert_equal(dm_draws, np.eye(3) * 5)

0 comments on commit c495d8a

Please sign in to comment.