Add blas_cores argument to pm.sample #7318

Merged on May 16, 2024 (3 commits)

@aseyboldt aseyboldt commented May 15, 2024

Description

We currently do not configure BLAS in any way. This can lead to very bad behavior if we sample several chains in parallel:
Many BLAS implementations default to using one worker thread per hardware thread in the machine. But if we sample in parallel with multiprocessing, each chain process uses an independent thread pool, so we end up starting chains*hardware_threads worker threads. Combined with the spin-locking that some BLAS implementations seem to do, this can lead to terrible performance.
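The oversubscription described above is simple arithmetic. The following is a minimal sketch, not code from this PR; the helper name `total_blas_workers` and the chain count of 4 are hypothetical and chosen only for illustration.

```python
import os


def total_blas_workers(parallel_chains: int, threads_per_pool: int) -> int:
    """Threads started when each chain process owns its own BLAS pool."""
    return parallel_chains * threads_per_pool


# os.cpu_count() reports logical CPUs; it can return None, so fall back to 1.
hardware_threads = os.cpu_count() or 1
parallel_chains = 4  # hypothetical number of chains sampled in parallel

workers = total_blas_workers(parallel_chains, hardware_threads)
print(f"{workers} BLAS workers competing for {hardware_threads} hardware threads")
```

With four chains, the worker count is four times the number of hardware threads, which is the situation where spin-locking workers start to starve each other.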

This PR adds a blas_cores argument to pm.sample(), and then uses threadpoolctl to control how many worker threads we start.
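As a rough illustration of the threadpoolctl side, the sketch below wraps `threadpool_limits` in a helper. The helper name `blas_limiter` and the fallback to a no-op context when threadpoolctl is not installed are assumptions made for this example, not the code merged in this PR.

```python
import contextlib

try:
    from threadpoolctl import threadpool_limits
except ImportError:
    # Assumption for this sketch only: degrade to a no-op without threadpoolctl.
    threadpool_limits = None


def blas_limiter(num_threads):
    """Return a context manager that caps native thread pools.

    With num_threads=None nothing is changed. Otherwise threadpool_limits
    applies the cap to all supported pools (BLAS and OpenMP) as soon as it
    is created and restores the previous limits when the context exits.
    """
    if num_threads is None or threadpool_limits is None:
        return contextlib.nullcontext()
    return threadpool_limits(limits=num_threads)


with blas_limiter(2):
    pass  # BLAS-heavy work would run here with at most 2 worker threads
```

Returning `contextlib.nullcontext()` for `None` keeps the calling code identical whether or not a limit is requested.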

If it is set to None, we don't do anything and keep the current behavior of using whatever the BLAS implementation uses by default. If it is set to "auto" (the default), we use the cores argument to guess a decent number of BLAS worker threads. If it is set to an integer, we use that number as the total number of BLAS worker threads.
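The three cases can be sketched as a small resolution function. This is a hedged illustration only: the function name `resolve_blas_cores`, the choice of `cores` as the "auto" guess, the even split across parallel chains, and the floor of one thread per chain are all assumptions, not a description of the merged implementation.

```python
def resolve_blas_cores(blas_cores, cores, chains):
    """Return (total_blas_threads, blas_threads_per_chain).

    Both values are None when blas_cores is None, meaning the BLAS
    library's own default is left untouched.
    """
    if blas_cores == "auto":
        # Assumption: use the number of sampling cores as the total budget.
        blas_cores = cores
    if blas_cores is None:
        return None, None
    if isinstance(blas_cores, bool) or not isinstance(blas_cores, int):
        raise ValueError(f"blas_cores must be int, 'auto' or None: {blas_cores!r}")
    if blas_cores < 1:
        raise ValueError("blas_cores must be larger or equal to one")

    # No more chains can run at once than there are cores or chains.
    parallel_chains = min(cores, chains)
    # Assumption: split the budget evenly, but never go below one thread.
    per_chain = max(1, blas_cores // parallel_chains)
    return blas_cores, per_chain


print(resolve_blas_cores("auto", cores=4, chains=4))
print(resolve_blas_cores(8, cores=4, chains=4))
print(resolve_blas_cores(None, cores=4, chains=4))
```

From the user's side, the argument is passed directly to the sampler, for example `pm.sample(cores=4, blas_cores=8)` to request eight BLAS worker threads in total, or `pm.sample(blas_cores=None)` to keep the previous behavior.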

See for instance here for a model that shows bad behavior without this PR.

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7318.org.readthedocs.build/en/7318/

@aseyboldt aseyboldt force-pushed the blas-cores branch 2 times, most recently from 51daa0c to 15cbbf4 Compare May 15, 2024 18:42
@@ -499,6 +504,13 @@ def sample(
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads blas and openmp functions should use during sampling. If set to None,
Member

Explain the default first?

Member Author

+1
Do you think we should already default to "auto", or first release something where the default is None so that this can be tested a bit more?

Member

I think the new default makes much more sense. This problem often shows up in MvNormal models, and it's very tricky for beginners to debug.

Comment on lines 664 to 671
if cores < 1:
    raise ValueError("`cores` must be larger or equal to one")

if chains < 1:
    raise ValueError("`chains` must be larger or equal to one")

if blas_cores is not None and blas_cores < 1:
    raise ValueError("`blas_cores` must be larger or equal to one")
Member

Remove it for the sake of less code? I don't believe anybody was ever hurt by this and couldn't figure out the problem?

tests/sampling/test_mcmc.py (resolved)
@ricardoV94
Member

pre-commit failed

@aseyboldt aseyboldt merged commit a144c43 into pymc-devs:main May 16, 2024
20 checks passed

2 participants