-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Decouple convergence checking from SamplerReport
#6453
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,11 @@ | |
from pymc.model import Model, modelcontext | ||
from pymc.sampling.parallel import Draw, _cpu_count | ||
from pymc.sampling.population import _sample_population | ||
from pymc.stats.convergence import log_warning_stats, run_convergence_checks | ||
from pymc.stats.convergence import ( | ||
log_warning_stats, | ||
log_warnings, | ||
run_convergence_checks, | ||
) | ||
from pymc.step_methods import NUTS, CompoundStep | ||
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared | ||
from pymc.step_methods.hmc import quadpotential | ||
|
@@ -602,7 +606,6 @@ def sample( | |
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) " | ||
f"took {t_sampling:.0f} seconds." | ||
) | ||
mtrace.report._log_summary() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In between the line 574 |
||
|
||
idata = None | ||
if compute_convergence_checks or return_inferencedata: | ||
|
@@ -612,14 +615,9 @@ def sample( | |
idata = pm.to_inference_data(mtrace, **ikwargs) | ||
|
||
if compute_convergence_checks: | ||
if draws - tune < 100: | ||
warnings.warn( | ||
"The number of samples is too small to check convergence reliably.", | ||
stacklevel=2, | ||
) | ||
Comment on lines
-616
to
-619
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is now checked by |
||
else: | ||
convergence_warnings = run_convergence_checks(idata, model) | ||
mtrace.report._add_warnings(convergence_warnings) | ||
warns = run_convergence_checks(idata, model) | ||
mtrace.report._add_warnings(warns) | ||
log_warnings(warns) | ||
|
||
if return_inferencedata: | ||
# By default we drop the "warning" stat which contains `SamplerWarning` | ||
|
@@ -925,9 +923,6 @@ def _mp_sample( | |
strace = traces[error._chain] | ||
for strace in traces: | ||
strace.close() | ||
|
||
multitrace = MultiTrace(traces) | ||
multitrace._report._log_summary() | ||
Comment on lines
-929
to
-930
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too: The |
||
raise | ||
except KeyboardInterrupt: | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
from collections import defaultdict | ||
from itertools import repeat | ||
from typing import Any, Dict, Optional, Tuple, Union | ||
|
||
import cloudpickle | ||
import numpy as np | ||
|
@@ -30,9 +31,10 @@ | |
|
||
from pymc.backends.arviz import dict_to_dataset, to_inference_data | ||
from pymc.backends.base import MultiTrace | ||
from pymc.model import modelcontext | ||
from pymc.model import Model, modelcontext | ||
from pymc.sampling.parallel import _cpu_count | ||
from pymc.smc.kernels import IMH | ||
from pymc.stats.convergence import log_warnings, run_convergence_checks | ||
from pymc.util import RandomState, _get_seeds_per_chain | ||
|
||
|
||
|
@@ -50,7 +52,7 @@ def sample_smc( | |
idata_kwargs=None, | ||
progressbar=True, | ||
**kernel_kwargs, | ||
): | ||
) -> Union[InferenceData, MultiTrace]: | ||
r""" | ||
Sequential Monte Carlo based sampling. | ||
|
||
|
@@ -236,20 +238,28 @@ def sample_smc( | |
) | ||
|
||
if compute_convergence_checks: | ||
_compute_convergence_checks(idata, draws, model, trace) | ||
return idata if return_inferencedata else trace | ||
if idata is None: | ||
idata = to_inference_data(trace, log_likelihood=False) | ||
warns = run_convergence_checks(idata, model) | ||
trace.report._add_warnings(warns) | ||
log_warnings(warns) | ||
Comment on lines
+241
to
+245
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This replaces the Remember from other changes:
|
||
|
||
if return_inferencedata: | ||
assert idata is not None | ||
return idata | ||
return trace | ||
|
||
|
||
def _save_sample_stats( | ||
sample_settings, | ||
sample_stats, | ||
chains, | ||
trace, | ||
return_inferencedata, | ||
trace: MultiTrace, | ||
return_inferencedata: bool, | ||
_t_sampling, | ||
idata_kwargs, | ||
model, | ||
): | ||
model: Model, | ||
) -> Tuple[Optional[Any], Optional[InferenceData]]: | ||
sample_settings_dict = sample_settings[0] | ||
sample_settings_dict["_t_sampling"] = _t_sampling | ||
sample_stats_dict = sample_stats[0] | ||
|
@@ -262,12 +272,12 @@ def _save_sample_stats( | |
value_list.append(chain_sample_stats[stat]) | ||
sample_stats_dict[stat] = value_list | ||
|
||
idata: Optional[InferenceData] = None | ||
if not return_inferencedata: | ||
for stat, value in sample_stats_dict.items(): | ||
setattr(trace.report, stat, value) | ||
for stat, value in sample_settings_dict.items(): | ||
setattr(trace.report, stat, value) | ||
idata = None | ||
else: | ||
for stat, value in sample_stats_dict.items(): | ||
if chains > 1: | ||
|
@@ -284,7 +294,7 @@ def _save_sample_stats( | |
library=pymc, | ||
) | ||
|
||
ikwargs = dict(model=model) | ||
ikwargs: Dict[str, Any] = dict(model=model) | ||
if idata_kwargs is not None: | ||
ikwargs.update(idata_kwargs) | ||
idata = to_inference_data(trace, **ikwargs) | ||
|
@@ -293,19 +303,6 @@ def _save_sample_stats( | |
return sample_stats, idata | ||
|
||
|
||
def _compute_convergence_checks(idata, draws, model, trace): | ||
if draws < 100: | ||
warnings.warn( | ||
"The number of samples is too small to check convergence reliably.", | ||
stacklevel=2, | ||
) | ||
else: | ||
if idata is None: | ||
idata = to_inference_data(trace, log_likelihood=False) | ||
trace.report._run_convergence_checks(idata, model) | ||
trace.report._log_summary() | ||
|
||
|
||
def _sample_smc_int( | ||
draws, | ||
kernel, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ | |
|
||
from abc import ABC, abstractmethod | ||
from enum import IntEnum, unique | ||
from typing import Dict, List, Sequence, Tuple, Union | ||
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
|
@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]: | |
class StatsBijection: | ||
"""Map between a `list` of stats to `dict` of stats.""" | ||
|
||
def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None: | ||
def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Typing rule of thumb: Generic input types, exact output types. |
||
# Keep a list of flat vs. original stat names | ||
self._stat_groups: List[List[Tuple[str, str]]] = [ | ||
[(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()] | ||
for s, names_dtypes in enumerate(sampler_stats_dtypes) | ||
] | ||
|
||
def map(self, stats_list: StatsType) -> StatsDict: | ||
def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict: | ||
"""Combine stats dicts of multiple samplers into one dict.""" | ||
stats_dict = {} | ||
for s, sts in enumerate(stats_list): | ||
|
@@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict: | |
stats_dict[sname] = sval | ||
return stats_dict | ||
|
||
def rmap(self, stats_dict: StatsDict) -> StatsType: | ||
def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType: | ||
"""Split a global stats dict into a list of sampler-wise stats dicts.""" | ||
stats_list = [] | ||
for namemap in self._stat_groups: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be
Sized
.