Skip to content

Commit

Permalink
Add more idata attributes for JAX samplers (#7360)
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
osyuksel and ricardoV94 authored Jun 17, 2024
1 parent bbd5739 commit 5bc6801
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ def sample_jax_nuts(

attrs = {
"sampling_time": (tic2 - tic1).total_seconds(),
"tuning_steps": tune,
}

coords, dims = coords_and_dims_for_inferencedata(model)
Expand All @@ -680,6 +681,7 @@ def sample_jax_nuts(
coords.update(idata_kwargs.pop("coords"))
if "dims" in idata_kwargs:
dims.update(idata_kwargs.pop("dims"))

# Use 'partial' to set default arguments before passing 'idata_kwargs'
to_trace = partial(
az.from_dict,
Expand All @@ -690,6 +692,7 @@ def sample_jax_nuts(
coords=coords,
dims=dims,
attrs=make_attrs(attrs, library=library),
posterior_attrs=make_attrs(attrs, library=library),
)
az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs)

Expand Down
1 change: 1 addition & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def _sample_external_nuts(
attrs = make_attrs(
{
"sampling_time": t_sample,
"tuning_steps": tune,
},
library=nutpie,
)
Expand Down
7 changes: 7 additions & 0 deletions tests/sampling/test_mcmc_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
idata1 = sample(**kwargs)
idata2 = sample(**kwargs)

reference_kwargs = kwargs.copy()
reference_kwargs["nuts_sampler"] = "pymc"
idata_reference = sample(**reference_kwargs)

warns = {
(warn.category, warn.message.args[0])
for warn in recwarn
Expand All @@ -64,8 +68,11 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
assert "L" in idata1.observed_data
assert idata1.posterior.chain.size == 2
assert idata1.posterior.draw.size == 500
assert idata1.posterior.tuning_steps == 500
np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)

assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()


def test_step_args():
with Model() as model:
Expand Down

0 comments on commit 5bc6801

Please sign in to comment.