Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add var_names argument to sample #7206

Merged
merged 9 commits into from
Mar 25, 2024

Conversation

fonnesbeck
Copy link
Member

@fonnesbeck fonnesbeck commented Mar 21, 2024

Description

Allow for filtering of variables included in sampled trace via an optional var_names argument, similar to what is done for plotting.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7206.org.readthedocs.build/en/7206/

@fonnesbeck
Copy link
Member Author

Just getting started here. Testing to see if this is the right approach.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @fonnesbeck

Left a small comment. Also I think the jax samplers kind of allow this functionality, so would only need to forward to the sample_external_nuts code path?

pymc/sampling/mcmc.py Show resolved Hide resolved
@ricardoV94
Copy link
Member

Apologies, I didn't see it was in draft :)

Copy link

codecov bot commented Mar 21, 2024

Codecov Report

Attention: Patch coverage is 50.00000% with 6 lines in your changes are missing coverage. Please review.

Project coverage is 90.29%. Comparing base (aa679f3) to head (13a5d31).
Report is 14 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7206      +/-   ##
==========================================
- Coverage   92.29%   90.29%   -2.01%     
==========================================
  Files         100      100              
  Lines       16875    16896      +21     
==========================================
- Hits        15575    15256     -319     
- Misses       1300     1640     +340     
Files Coverage Δ
pymc/backends/__init__.py 92.10% <100.00%> (+0.21%) ⬆️
pymc/sampling/mcmc.py 85.96% <66.66%> (-2.04%) ⬇️
pymc/sampling/jax.py 0.00% <0.00%> (-94.10%) ⬇️

... and 6 files with indirect coverage changes

@fonnesbeck
Copy link
Member Author

fonnesbeck commented Mar 21, 2024

No worries, feedback welcome at all stages! (earlier the better, in fact)

@fonnesbeck
Copy link
Member Author

Should probably enforce that all stochastic variables be included.

@ricardoV94
Copy link
Member

Should probably enforce that all stochastic variables be included.

Maybe it's fine not to?

@fonnesbeck
Copy link
Member Author

The numpyro sampler does not appear to do the right thing with the passed var names.

@ricardoV94
Copy link
Member

The numpyro sampler does not appear to do the right thing with the passed var names.

What does it do? Checking the source code, it looks like it should if you pass just the strings?

pymc/sampling/mcmc.py Outdated Show resolved Hide resolved
Better docstring

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@fonnesbeck
Copy link
Member Author

Seems like at some point the names don't get converted vars.

@fonnesbeck
Copy link
Member Author

(ok, fixed)

@fonnesbeck fonnesbeck marked this pull request as ready for review March 22, 2024 16:21
@fonnesbeck
Copy link
Member Author

Not sure how code coverage drops when I've added two tests.

@ricardoV94
Copy link
Member

Not sure how code coverage drops when I've added two tests.

Cov is flaky, not always up to date or comparing with the right commit

@@ -348,6 +349,7 @@ def _sample_external_nuts(
random_seed=random_seed,
initvals=initvals,
model=model,
var_names=var_names,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a warning about var_names not beeing used by nutpie like we have for some other arguments above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, @aseyboldt how hard/reasonable is it to support this in nutpie?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it could be filtered in nutpie.sample._trace_to_arviz?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we want to filter during sampling already since RAM is usually the issue, not disk-space?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but there are no obvious hooks into the nutpie compiled model. It would require some changes on the nutpie side, by the looks of it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My top comment was to add a warning like these:

pymc/pymc/sampling/mcmc.py

Lines 283 to 294 in abe7bc9

if initvals is not None:
warnings.warn(
"`initvals` are currently not passed to nutpie sampler. "
"Use `init_mean` kwarg following nutpie specification instead.",
UserWarning,
)
if idata_kwargs is not None:
warnings.warn(
"`idata_kwargs` are currently ignored by the nutpie sampler",
UserWarning,
)

Not to try to monkey-patch nutpie from the outside

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once/if nutpie has similar functionality we can forward it from pymc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go with the warning for now, and create an issue on nutpie for a solution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be too hard. Nutpie uses a numba function to compute all values that should appear in the trace (including the deterministics and transformed values). We should be able to just export a subset (code is around here: https://github.com/pymc-devs/nutpie/blob/main/python/nutpie/compile_pymc.py#L387)

pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 merged commit 4c3ec36 into pymc-devs:main Mar 25, 2024
23 of 24 checks passed
@ricardoV94
Copy link
Member

Thanks @fonnesbeck !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Allow storing subset of variable in pm.sample via a var_names kwarg
3 participants