Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace fastprogress progress bars with rich #7233

Merged
merged 11 commits into from
Apr 3, 2024
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- numpy>=1.15.0
- pandas>=0.24.0
Expand All @@ -28,6 +27,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- rich>=13.7.1
- sphinx-copybutton
- sphinx-design
- sphinx-notfound-page
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ dependencies:
- arviz>=0.13.0
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.19,<2.20
- python-graphviz
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for docs build
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax>=1.0.0
Expand All @@ -24,6 +23,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- jax
- libblas=*=*mkl
Expand All @@ -20,6 +19,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for dev, testing and docs build
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- libpython
- mkl-service>=2.3.0
Expand All @@ -20,6 +19,7 @@ dependencies:
- pytensor>=2.19,<2.20
- python-graphviz
- networkx
- rich>=13.7.1
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
48 changes: 26 additions & 22 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import xarray

from arviz import InferenceData
from fastprogress.fastprogress import progress_bar
from pytensor import tensor as pt
from pytensor.graph.basic import (
Apply,
Expand All @@ -46,6 +45,9 @@
RandomStateSharedVariable,
)
from pytensor.tensor.sharedvar import SharedVariable
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from typing_extensions import TypeAlias

import pymc as pm
Expand All @@ -59,6 +61,7 @@
RandomState,
_get_seeds_per_chain,
dataset_to_point_list,
default_progress_theme,
get_default_varnames,
point_wrapper,
)
Expand All @@ -70,7 +73,6 @@
"sample_posterior_predictive",
)


ArrayLike: TypeAlias = Union[np.ndarray, list[float]]
PointList: TypeAlias = list[PointType]

Expand Down Expand Up @@ -442,6 +444,7 @@ def sample_posterior_predictive(
sample_dims: Optional[list[str]] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
return_inferencedata: bool = True,
extend_inferencedata: bool = False,
predictions: bool = False,
Expand Down Expand Up @@ -796,10 +799,6 @@ def sample_posterior_predictive(
else:
vars_ = model.observed_RVs + observed_dependent_deterministics(model)

indices = np.arange(samples)
if progressbar:
indices = progress_bar(indices, total=samples, display=progressbar)

vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))

if not vars_to_sample:
Expand Down Expand Up @@ -834,25 +833,30 @@ def sample_posterior_predictive(
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)
try:
for idx in indices:
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
# ... or a PointList
with Progress(console=Console(theme=progressbar_theme)) as progress:
task = progress.add_task("Sampling ...", total=samples, visible=progressbar)
for idx in np.arange(samples):
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
if hasattr(_trace, "_straces"):
chain_idx, point_idx = np.divmod(idx, len_trace)
chain_idx = chain_idx % nchain
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
# ... or a PointList
else:
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
# there's only a single chain, but the index might hit it multiple times if
# the number of indices is greater than the length of the trace.
else:
param = cast(PointList, _trace)[idx % (len_trace * nchain)]
# there's only a single chain, but the index might hit it multiple times if
# the number of indices is greater than the length of the trace.
else:
param = _trace[idx % len_trace]
param = _trace[idx % len_trace]

values = sampler_fn(**param)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)

values = sampler_fn(**param)
progress.advance(task)

for k, v in zip(vars_, values):
ppc_trace_t.insert(k.name, v, idx)
except KeyboardInterrupt:
pass

Expand Down
39 changes: 25 additions & 14 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@

from arviz import InferenceData, dict_to_dataset
from arviz.data.base import make_attrs
from fastprogress.fastprogress import progress_bar
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import Progress
from rich.theme import Theme
from typing_extensions import Protocol, TypeAlias

import pymc as pm
Expand Down Expand Up @@ -65,6 +67,7 @@
RandomSeed,
RandomState,
_get_seeds_per_chain,
default_progress_theme,
drop_warning_stat,
get_untransformed_name,
is_transformed_name,
Expand Down Expand Up @@ -377,6 +380,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -406,6 +410,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -435,6 +440,7 @@ def sample(
cores: Optional[int] = None,
random_seed: RandomState = None,
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
step=None,
var_names: Optional[Sequence[str]] = None,
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
Expand Down Expand Up @@ -761,6 +767,7 @@ def sample(
"tune": tune,
"var_names": var_names,
"progressbar": progressbar,
"progressbar_theme": progressbar_theme,
"model": model,
"cores": cores,
"callback": callback,
Expand Down Expand Up @@ -983,6 +990,7 @@ def _sample(
trace: IBaseTrace,
tune: int,
model: Optional[Model] = None,
progressbar_theme: Optional[Theme] = default_progress_theme,
callback=None,
**kwargs,
) -> None:
Expand Down Expand Up @@ -1010,6 +1018,8 @@ def _sample(
tune : int
Number of iterations to tune.
model : Model (optional if in ``with`` context)
progressbar_theme : Theme
Optional custom theme for the progress bar.
"""
skip_first = kwargs.get("skip_first", 0)

Expand All @@ -1026,19 +1036,16 @@ def _sample(
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
if progressbar:
sampling = progress_bar(sampling_gen, total=draws, display=progressbar)
sampling.comment = _desc.format(**_pbar_data)
else:
sampling = sampling_gen
try:
for it, diverging in enumerate(sampling):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
if progressbar:
sampling.comment = _desc.format(**_pbar_data)
except KeyboardInterrupt:
pass
with Progress(console=Console(theme=progressbar_theme)) as progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), total=draws, visible=progressbar)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task, advance=1)
progress.update(task, advance=1, completed=True)
except KeyboardInterrupt:
pass


def _iter_sample(
Expand Down Expand Up @@ -1131,6 +1138,7 @@ def _mp_sample(
random_seed: Sequence[RandomSeed],
start: Sequence[PointType],
progressbar: bool = True,
progressbar_theme: Optional[Theme] = default_progress_theme,
traces: Sequence[IBaseTrace],
model: Optional[Model] = None,
callback: Optional[SamplingIteratorCallback] = None,
Expand Down Expand Up @@ -1158,6 +1166,8 @@ def _mp_sample(
Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
progressbar : bool
Whether or not to display a progress bar in the command line.
progressbar_theme : Theme
Optional custom theme for the progress bar.
traces
Recording backends for each chain.
model : Model (optional if in ``with`` context)
Expand All @@ -1182,6 +1192,7 @@ def _mp_sample(
start_points=start,
step_method=step,
progressbar=progressbar,
progressbar_theme=progressbar_theme,
mp_ctx=mp_ctx,
)
try:
Expand Down
Loading
Loading