Skip to content

Commit

Permalink
feat: Add blas_cores argument to pm.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed May 16, 2024
1 parent 4e8e986 commit 0cebb66
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 47 deletions.
115 changes: 84 additions & 31 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

"""Functions for MCMC sampling."""

import contextlib
import logging
import pickle
import sys
import time
import warnings

from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import (
Any,
Literal,
Expand All @@ -37,6 +38,7 @@
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol

import pymc as pm
Expand Down Expand Up @@ -396,6 +398,7 @@ def sample(
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
**kwargs,
) -> InferenceData: ...

Expand Down Expand Up @@ -427,6 +430,7 @@ def sample(
callback=None,
mp_ctx=None,
model: Model | None = None,
blas_cores: int | None | Literal["auto"] = "auto",
**kwargs,
) -> MultiTrace: ...

Expand Down Expand Up @@ -456,6 +460,7 @@ def sample(
nuts_sampler_kwargs: dict[str, Any] | None = None,
callback=None,
mp_ctx=None,
blas_cores: int | None | Literal["auto"] = "auto",
model: Model | None = None,
**kwargs,
) -> InferenceData | MultiTrace:
Expand Down Expand Up @@ -499,6 +504,13 @@ def sample(
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
This requires the chosen sampler to be installed.
All samplers, except "pymc", require the full model to be continuous.
blas_cores: int or "auto" or None, default = "auto"
The total number of threads blas and openmp functions should use during sampling. If set to None,
this will keep the default behavior of whatever blas implementation is used at runtime.
Setting it to "auto" will set it so that the total number of active blas threads is the
same as the `cores` argument. If set to an integer, the sampler will try to use that total
number of blas threads. If `blas_cores` is not divisible by `cores`, it might get rounded
down.
initvals : optional, dict, array of dict
Dict or list of dicts with initial value strategies to use instead of the defaults from
`Model.initial_values`. The keys should be names of transformed random variables.
Expand Down Expand Up @@ -644,6 +656,37 @@ def sample(
if chains is None:
chains = max(2, cores)

if blas_cores == "auto":
blas_cores = cores

cores = min(cores, chains)

if cores < 1:
raise ValueError("`cores` must be larger or equal to one")

if chains < 1:
raise ValueError("`chains` must be larger or equal to one")

if blas_cores is not None and blas_cores < 1:
raise ValueError("`blas_cores` must be larger or equal to one")

num_blas_cores_per_chain: int | None
joined_blas_limiter: Callable[[], Any]

if blas_cores is None:
joined_blas_limiter = contextlib.nullcontext
num_blas_cores_per_chain = None
elif isinstance(blas_cores, int):

def joined_blas_limiter():
return threadpool_limits(limits=blas_cores)

num_blas_cores_per_chain = blas_cores // cores
else:
raise ValueError(
f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}"
)

if random_seed == -1:
random_seed = None
random_seed_list = _get_seeds_per_chain(random_seed, chains)
Expand Down Expand Up @@ -685,21 +728,22 @@ def sample(
raise ValueError(
"Model can not be sampled with NUTS alone. Your model is probably not continuous."
)
return _sample_external_nuts(
sampler=nuts_sampler,
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)
with joined_blas_limiter():
return _sample_external_nuts(
sampler=nuts_sampler,
draws=draws,
tune=tune,
chains=chains,
target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
random_seed=random_seed,
initvals=initvals,
model=model,
var_names=var_names,
progressbar=progressbar,
idata_kwargs=idata_kwargs,
nuts_sampler_kwargs=nuts_sampler_kwargs,
**kwargs,
)

if isinstance(step, list):
step = CompoundStep(step)
Expand All @@ -708,18 +752,19 @@ def sample(
nuts_kwargs = kwargs.pop("nuts")
[kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
_log.info("Auto-assigning NUTS sampler...")
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)
with joined_blas_limiter():
initial_points, step = init_nuts(
init=init,
chains=chains,
n_init=n_init,
model=model,
random_seed=random_seed_list,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
initvals=initvals,
**kwargs,
)

if initial_points is None:
# Time to draw/evaluate numeric start points for each chain.
Expand Down Expand Up @@ -756,7 +801,8 @@ def sample(
)

sample_args = {
"draws": draws + tune, # FIXME: Why is tune added to draws?
# draws is now the total number of draws, including tuning
"draws": draws + tune,
"step": step,
"start": initial_points,
"traces": traces,
Expand All @@ -772,6 +818,7 @@ def sample(
}
parallel_args = {
"mp_ctx": mp_ctx,
"blas_cores": num_blas_cores_per_chain,
}

sample_args.update(kwargs)
Expand Down Expand Up @@ -817,11 +864,15 @@ def sample(
if has_population_samplers:
_log.info(f"Population sampling ({chains} chains)")
_print_step_hierarchy(step)
_sample_population(initial_points=initial_points, parallelize=cores > 1, **sample_args)
with joined_blas_limiter():
_sample_population(
initial_points=initial_points, parallelize=cores > 1, **sample_args
)
else:
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
_print_step_hierarchy(step)
_sample_many(**sample_args)
with joined_blas_limiter():
_sample_many(**sample_args)

t_sampling = time.time() - t_start

Expand Down Expand Up @@ -1139,6 +1190,7 @@ def _mp_sample(
traces: Sequence[IBaseTrace],
model: Model | None = None,
callback: SamplingIteratorCallback | None = None,
blas_cores: int | None = None,
mp_ctx=None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -1190,6 +1242,7 @@ def _mp_sample(
step_method=step,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
blas_cores=blas_cores,
mp_ctx=mp_ctx,
)
try:
Expand Down
40 changes: 24 additions & 16 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from rich.console import Console
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
draws: int,
tune: int,
seed,
blas_cores,
):
self._msg_pipe = msg_pipe
self._step_method = step_method
Expand All @@ -102,6 +104,7 @@ def __init__(
self._at_seed = seed + 1
self._draws = draws
self._tune = tune
self._blas_cores = blas_cores

def _unpickle_step_method(self):
unpickle_error = (
Expand All @@ -116,22 +119,23 @@ def _unpickle_step_method(self):
raise ValueError(unpickle_error)

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
# Send is not blocking so we have to force a wait for the abort
# message
self._msg_pipe.send(("error", e))
self._wait_for_abortion()
finally:
self._msg_pipe.close()
with threadpool_limits(limits=self._blas_cores):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
pass
except BaseException as e:
e = ExceptionWithTraceback(e, e.__traceback__)
# Send is not blocking so we have to force a wait for the abort
# message
self._msg_pipe.send(("error", e))
self._wait_for_abortion()
finally:
self._msg_pipe.close()

def _wait_for_abortion(self):
while True:
Expand Down Expand Up @@ -208,6 +212,7 @@ def __init__(
chain: int,
seed,
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
):
self.chain = chain
Expand Down Expand Up @@ -256,6 +261,7 @@ def __init__(
draws,
tune,
seed,
blas_cores,
),
)
self._process.start()
Expand Down Expand Up @@ -378,6 +384,7 @@ def __init__(
step_method,
progressbar: bool = True,
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
):
if any(len(arg) != chains for arg in [seeds, start_points]):
Expand Down Expand Up @@ -411,6 +418,7 @@ def __init__(
chain,
seed,
start,
blas_cores,
mp_ctx,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ pandas>=0.24.0
pytensor>=2.20,<2.21
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
typing-extensions>=3.7.4
8 changes: 8 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,14 @@ def test_empty_model():
error.match("any free variables")


def test_blas_cores():
with pm.Model():
pm.Normal("a")
pm.sample(blas_cores="auto", tune=10, cores=2, draws=10)
pm.sample(blas_cores=None, tune=10, cores=2, draws=10)
pm.sample(blas_cores=2, tune=10, cores=2, draws=10)


def test_partial_trace_with_trace_unsupported():
with pm.Model() as model:
a = pm.Normal("a", mu=0, sigma=1)
Expand Down
2 changes: 2 additions & 0 deletions tests/sampling/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def test_explicit_sample(mp_start_method):
mp_ctx=ctx,
start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))},
step_method_pickled=step_method_pickled,
blas_cores=None,
)
proc.start()
while True:
Expand Down Expand Up @@ -193,6 +194,7 @@ def test_iterator():
start_points=[start] * 3,
step_method=step,
progressbar=False,
blas_cores=None,
)
with sampler:
for draw in sampler:
Expand Down

0 comments on commit 0cebb66

Please sign in to comment.