Add blas_cores argument to pm.sample #7318

Merged · 3 commits · May 16, 2024
Changes from 2 commits
1 change: 1 addition & 0 deletions conda-envs/environment-dev.yml
@@ -18,6 +18,7 @@ dependencies:
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
- threadpoolctl>=3.1.0
# Extra dependencies for dev, testing and docs build
- ipython>=7.16
- jax
1 change: 1 addition & 0 deletions conda-envs/environment-docs.yml
@@ -16,6 +16,7 @@ dependencies:
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
- threadpoolctl>=3.1.0
# Extra dependencies for docs build
- ipython>=7.16
- jax
1 change: 1 addition & 0 deletions conda-envs/environment-jax.yml
@@ -24,6 +24,7 @@ dependencies:
- python-graphviz
- networkx
- rich>=13.7.1
- threadpoolctl>=3.1.0
# JAX is only compatible with Scipy 1.13.0 from >=0.4.26, but the respective version of
# JAXlib is still not on conda: https://github.com/conda-forge/jaxlib-feedstock/pull/243
- scipy>=1.4.1,<1.13.0
1 change: 1 addition & 0 deletions conda-envs/environment-test.yml
@@ -22,6 +22,7 @@ dependencies:
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
- threadpoolctl>=3.1.0
# Extra dependencies for testing
- ipython>=7.16
- pre-commit>=2.8.0
1 change: 1 addition & 0 deletions conda-envs/windows-environment-dev.yml
@@ -19,6 +19,7 @@ dependencies:
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
- threadpoolctl>=3.1.0
# Extra dependencies for dev, testing and docs build
- ipython>=7.16
- myst-nb
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test.yml
@@ -22,6 +22,7 @@ dependencies:
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
- threadpoolctl>=3.1.0
# Extra dependencies for testing
- ipython>=7.16
- pre-commit>=2.8.0
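The new `threadpoolctl>=3.1.0` dependency provides the `threadpool_limits` context manager that the sampling code below relies on to cap BLAS/OpenMP threads. A minimal, PyMC-independent sketch of what it does (assumes NumPy is linked against a BLAS library that threadpoolctl can detect):

```python
import numpy as np
from threadpoolctl import threadpool_limits

a = np.random.rand(2000, 2000)

# Inside the block every detected BLAS/OpenMP thread pool is capped at 2
# threads; outside the block the library defaults apply again.
with threadpool_limits(limits=2):
    a @ a
```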
106 changes: 75 additions & 31 deletions pymc/sampling/mcmc.py
@@ -14,13 +14,14 @@

"""Functions for MCMC sampling."""

import contextlib
import logging
import pickle
import sys
import time
import warnings

from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (
Any,
Literal,
@@ -37,6 +38,7 @@
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol

import pymc as pm
@@ -396,6 +398,7 @@ def sample(
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
**kwargs,
) -> InferenceData: ...

@@ -427,6 +430,7 @@ def sample(
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
**kwargs,
) -> MultiTrace: ...

@@ -456,6 +460,7 @@ def sample(
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
**kwargs,
) -> InferenceData | MultiTrace:
@@ -499,6 +504,13 @@ def sample(
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads that BLAS and OpenMP functions should use during sampling.
Setting it to "auto" ensures that the total number of active BLAS threads matches the
`cores` argument. If set to an integer, the sampler will try to use that total number of
BLAS threads. If `blas_cores` is not divisible by `cores`, the per-chain thread count is
rounded down, so the effective total may be lower. If set to None, the default behavior
of whatever BLAS implementation is used at runtime is kept.
initvals : optional, dict, array of dict
Dict or list of dicts with initial value strategies to use instead of the defaults from
`Model.initial_values`. The keys should be names of transformed random variables.
@@ -644,6 +656,28 @@ def sample(
if chains is None:
chains = max(2, cores)

if blas_cores == "auto":
blas_cores = cores

cores = min(cores, chains)

num_blas_cores_per_chain: int | None
joined_blas_limiter: Callable[[], Any]

if blas_cores is None:
joined_blas_limiter = contextlib.nullcontext
num_blas_cores_per_chain = None
elif isinstance(blas_cores, int):

def joined_blas_limiter():
return threadpool_limits(limits=blas_cores)

num_blas_cores_per_chain = blas_cores // cores
else:
raise ValueError(
f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}"
)

if random_seed == -1:
random_seed = None
random_seed_list = _get_seeds_per_chain(random_seed, chains)
@@ -685,21 +719,22 @@ def sample(
raise ValueError(
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
)
return _sample_external_nuts(
sampler=nuts_sampler,
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)
with joined_blas_limiter():
return _sample_external_nuts(
sampler=nuts_sampler,
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)

if isinstance(step, list):
step = CompoundStep(step)
@@ -708,18 +743,19 @@ def sample(
nuts_kwargs = kwargs.pop("nuts")
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
_log.info("Auto-assigning NUTS sampler...")
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)
with joined_blas_limiter():
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)

if initial_points is None:
# Time to draw/evaluate numeric start points for each chain.
@@ -756,7 +792,8 @@ def sample(
)

sample_args = {
"draws": draws + tune, # FIXME: Why is tune added to draws?
# draws is now the total number of draws, including tuning
"draws": draws + tune,
"step": step,
"start": initial_points,
"traces": traces,
@@ -772,6 +809,7 @@ def sample(
}
parallel_args = {
"mp_ctx": mp_ctx,
"blas_cores": num_blas_cores_per_chain,
}

sample_args.update(kwargs)
@@ -817,11 +855,15 @@ def sample(
if has_population_samplers:
_log.info(f"Population sampling ({chains} chains)")
_print_step_hierarchy(step)
_sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args)
with joined_blas_limiter():
_sample_population(
initial_points=initial_points, parallelize=cores > 1, **sample_args
)
else:
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
_print_step_hierarchy(step)
_sample_many(**sample_args)
with joined_blas_limiter():
_sample_many(**sample_args)

t_sampling = time.time() - t_start

@@ -1139,6 +1181,7 @@ def _mp_sample(
traces: Sequence[IBaseTrace],
model: Model | None = None,
callback: SamplingIteratorCallback | None = None,
blas_cores: int | None = None,
mp_ctx=None,
**kwargs,
) -> None:
@@ -1190,6 +1233,7 @@ def _mp_sample(
step_method=step,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
blas_cores=blas_cores,
mp_ctx=mp_ctx,
)
try:
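The resolution logic added in `pm.sample` maps the user-facing `blas_cores` argument onto two things: a limiter applied around single-process phases (NUTS initialization, external NUTS samplers, sequential/population sampling) and an integer per-chain budget handed to the worker processes. A rough standalone sketch of that mapping; the helper name `resolve_blas_cores` and its return value are illustrative, not part of the PR:

```python
import contextlib
from collections.abc import Callable
from typing import Any, Literal

from threadpoolctl import threadpool_limits


def resolve_blas_cores(
    blas_cores: int | None | Literal["auto"], cores: int, chains: int
) -> tuple[Callable[[], Any], int | None]:
    """Mirror the blas_cores handling in pm.sample (illustrative only)."""
    if blas_cores == "auto":
        blas_cores = cores
    cores = min(cores, chains)

    if blas_cores is None:
        # Keep whatever the BLAS implementation does by default.
        return contextlib.nullcontext, None
    if isinstance(blas_cores, int):
        # Global limiter plus an integer-divided per-chain budget.
        return (lambda: threadpool_limits(limits=blas_cores)), blas_cores // cores
    raise ValueError(f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}")


# blas_cores="auto" with cores=4 and chains=4 gives 4 total BLAS threads, 1 per chain.
limiter, per_chain = resolve_blas_cores("auto", cores=4, chains=4)
with limiter():
    pass  # initialization or an external sampler would run here
```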
40 changes: 24 additions & 16 deletions pymc/sampling/parallel.py
@@ -29,6 +29,7 @@
from rich.console import Console
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
@@ -93,6 +94,7 @@ def __init__(
draws: int,
tune: int,
seed,
blas_cores,
):
self._msg_pipe = msg_pipe
self._step_method = step_method
@@ -102,6 +104,7 @@ def __init__(
self._at_seed = seed + 1
self._draws = draws
self._tune = tune
self._blas_cores = blas_cores

def _unpickle_step_method(self):
unpickle_error = (
@@ -116,22 +119,23 @@ def _unpickle_step_method(self):
raise ValueError(unpickle_error)

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
# Send is not blocking so we have to force a wait for the abort
# message
self._msg_pipe.send(("error", e))
self._wait_for_abortion()
finally:
self._msg_pipe.close()
with threadpool_limits(limits=self._blas_cores):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
# Send is not blocking so we have to force a wait for the abort
# message
self._msg_pipe.send(("error", e))
self._wait_for_abortion()
finally:
self._msg_pipe.close()

def _wait_for_abortion(self):
while True:
@@ -208,6 +212,7 @@ def __init__(
chain: int,
seed,
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
):
self.chain = chain
@@ -256,6 +261,7 @@ def __init__(
draws,
tune,
seed,
blas_cores,
),
)
self._process.start()
@@ -378,6 +384,7 @@ def __init__(
step_method,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
@@ -411,6 +418,7 @@ def __init__(
chain,
seed,
start,
blas_cores,
mp_ctx,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
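In `parallel.py` each worker process applies the per-chain limit itself, since a limit set in the parent does not carry over into a freshly spawned process. A minimal illustration of that pattern outside PyMC (the `_worker` function and the multiprocessing setup here are hypothetical, not code from the PR):

```python
import multiprocessing as mp

import numpy as np
from threadpoolctl import threadpool_limits


def _worker(blas_cores: int | None) -> None:
    # limits=None leaves the BLAS defaults untouched, matching the
    # behaviour of blas_cores=None in pm.sample.
    with threadpool_limits(limits=blas_cores):
        a = np.random.rand(500, 500)
        a @ a  # BLAS-backed work in this block is capped at `blas_cores` threads


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=_worker, args=(1,)) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```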
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,4 +6,5 @@ pandas>=0.24.0
pytensor>=2.20,<2.21
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
typing-extensions>=3.7.4
8 changes: 8 additions & 0 deletions tests/sampling/test_mcmc.py
@@ -507,6 +507,14 @@ def test_empty_model():
error.match("any free variables")


def test_blas_cores():
with pm.Model():
pm.Normal("a")
pm.sample(blas_cores="auto", tune=10, cores=2, draws=10)
pm.sample(blas_cores=None, tune=10, cores=2, draws=10)
pm.sample(blas_cores=2, tune=10, cores=2, draws=10)


def test_partial_trace_with_trace_unsupported():
with pm.Model() as model:
a = pm.Normal("a", mu=0, sigma=1)
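The new test only checks that the three accepted values run end to end. For a quick interactive check that a limit is actually in effect, threadpoolctl's `threadpool_info()` can be inspected inside a limited block (a sketch, not part of the PR):

```python
from pprint import pprint

from threadpoolctl import threadpool_info, threadpool_limits

with threadpool_limits(limits=2):
    # Each detected BLAS/OpenMP library should now report num_threads <= 2.
    pprint(threadpool_info())
```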
2 changes: 2 additions & 0 deletions tests/sampling/test_parallel.py
@@ -161,6 +161,7 @@ def test_explicit_sample(mp_start_method):
mp_ctx=ctx,
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
step_method_pickled=step_method_pickled,
blas_cores=None,
)
proc.start()
while True:
@@ -193,6 +194,7 @@ def test_iterator():
start_points=[start] * 3,
step_method=step,
progressbar=False,
blas_cores=None,
)
with sampler:
for draw in sampler: