Decouple convergence checking from SamplerReport #6453

Merged
18 changes: 16 additions & 2 deletions pymc/backends/base.py
@@ -21,7 +21,18 @@
import warnings

from abc import ABC
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast
from typing import (
Dict,
List,
Optional,
Sequence,
Set,
Sized,
Tuple,
TypeVar,
Union,
cast,
)

import numpy as np

@@ -510,7 +521,10 @@ def _squeeze_cat(results, combine, squeeze):
return results


def _choose_chains(traces: Sequence[BaseTrace], tune: int) -> Tuple[List[BaseTrace], int]:
S = TypeVar("S", bound=Sized)


def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
Comment on lines +524 to +527 (Member Author):

This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be Sized.
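
For illustration, a toy function with the same shape (not from the PR) shows how the bound TypeVar behaves under a type checker:

```python
from typing import List, Sequence, Sized, Tuple, TypeVar

S = TypeVar("S", bound=Sized)

def keep_nonempty(items: Sequence[S]) -> Tuple[List[S], int]:
    # The only requirement on S is that len() works on its items.
    kept = [x for x in items if len(x) > 0]
    return kept, len(kept)

# The element type flows through: mypy infers Tuple[List[str], int] here.
strings, n = keep_nonempty(["a", "", "bc"])
```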

"""
Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized.

18 changes: 1 addition & 17 deletions pymc/backends/report.py
@@ -17,14 +17,7 @@

from typing import Dict, List, Optional

import arviz

from pymc.stats.convergence import (
_LEVELS,
SamplerWarning,
log_warnings,
run_convergence_checks,
)
from pymc.stats.convergence import _LEVELS, SamplerWarning

logger = logging.getLogger("pymc")

@@ -73,22 +66,13 @@ def raise_ok(self, level="error"):
if errors:
raise ValueError("Serious convergence issues during sampling.")

def _run_convergence_checks(self, idata: arviz.InferenceData, model):
warnings = run_convergence_checks(idata, model)
self._add_warnings(warnings)

def _add_warnings(self, warnings, chain=None):
if chain is None:
warn_list = self._global_warnings
else:
warn_list = self._chain_warnings.setdefault(chain, [])
warn_list.extend(warnings)

def _log_summary(self):
for chain, warns in self._chain_warnings.items():
log_warnings(warns)
log_warnings(self._global_warnings)

def _slice(self, start, stop, step):
report = SamplerReport()

21 changes: 8 additions & 13 deletions pymc/sampling/mcmc.py
@@ -40,7 +40,11 @@
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import Draw, _cpu_count
from pymc.sampling.population import _sample_population
from pymc.stats.convergence import log_warning_stats, run_convergence_checks
from pymc.stats.convergence import (
log_warning_stats,
log_warnings,
run_convergence_checks,
)
from pymc.step_methods import NUTS, CompoundStep
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
@@ -602,7 +606,6 @@ def sample(
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)
mtrace.report._log_summary()
Member Author:

Between line 574, mtrace = MultiTrace(traces)[:length], where the MultiTrace was created, and this point, no warnings were added to mtrace. Therefore there are no warnings to log, and the _log_summary() call can safely be removed.

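A quick way to verify that claim (a hypothetical check; attribute names taken from the report.py diff above):

```python
mt = MultiTrace(traces)[:length]
# Freshly constructed: nothing has called _add_warnings on its report yet.
assert mt.report._global_warnings == []
assert mt.report._chain_warnings == {}
```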

idata = None
if compute_convergence_checks or return_inferencedata:
@@ -612,14 +615,9 @@
idata = pm.to_inference_data(mtrace, **ikwargs)

if compute_convergence_checks:
if draws - tune < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
Comment on lines -616 to -619 (Member Author):

This is now checked by run_convergence_checks (see the new check in pymc/stats/convergence.py below), just as it already checks for a minimum number of chains.

else:
convergence_warnings = run_convergence_checks(idata, model)
mtrace.report._add_warnings(convergence_warnings)
warns = run_convergence_checks(idata, model)
mtrace.report._add_warnings(warns)
log_warnings(warns)

if return_inferencedata:
# By default we drop the "warning" stat which contains `SamplerWarning`
@@ -925,9 +923,6 @@ def _mp_sample(
strace = traces[error._chain]
for strace in traces:
strace.close()

multitrace = MultiTrace(traces)
multitrace._report._log_summary()
Comment on lines -929 to -930 (Member Author):

Here too: the multitrace cannot have any warnings for _log_summary() to print, because none were added here or in its __init__.

raise
except KeyboardInterrupt:
pass
18 changes: 11 additions & 7 deletions pymc/smc/kernels.py
@@ -16,14 +16,15 @@
import warnings

from abc import ABC
from typing import Dict, cast
from typing import Dict, Union, cast

import numpy as np
import pytensor.tensor as at

from pytensor.graph.replace import clone_replace
from scipy.special import logsumexp
from scipy.stats import multivariate_normal
from typing_extensions import TypeAlias

from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
@@ -39,6 +40,9 @@
from pymc.step_methods.metropolis import MultivariateNormalProposal
from pymc.vartypes import discrete_types

SMCStats: TypeAlias = Dict[str, Union[int, float]]
SMCSettings: TypeAlias = Dict[str, Union[int, float]]


class SMC_KERNEL(ABC):
"""Base class for the Sequential Monte Carlo kernels.
@@ -304,7 +308,7 @@ def mutate(self):
"""Apply kernel-specific perturbation to the particles once per stage"""
pass

def sample_stats(self) -> Dict:
def sample_stats(self) -> SMCStats:
"""Stats to be saved at the end of each stage

These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -314,7 +318,7 @@ def sample_stats(self) -> Dict:
"beta": self.beta,
}

def sample_settings(self) -> Dict:
def sample_settings(self) -> SMCSettings:
"""SMC_kernel settings to be saved once at the end of sampling.

These stats will be saved under `sample_stats` in the final InferenceData object.
@@ -425,7 +429,7 @@ def mutate(self):

self.acc_rate = np.mean(ac_)

def sample_stats(self):
def sample_stats(self) -> SMCStats:
stats = super().sample_stats()
stats.update(
{
@@ -434,7 +438,7 @@
)
return stats

def sample_settings(self):
def sample_settings(self) -> SMCSettings:
stats = super().sample_settings()
stats.update(
{
@@ -543,7 +547,7 @@ def mutate(self):

self.chain_acc_rate = np.mean(ac_, axis=0)

def sample_stats(self):
def sample_stats(self) -> SMCStats:
stats = super().sample_stats()
stats.update(
{
@@ -553,7 +557,7 @@
)
return stats

def sample_settings(self):
def sample_settings(self) -> SMCSettings:
stats = super().sample_settings()
stats.update(
{
43 changes: 20 additions & 23 deletions pymc/smc/sampling.py
@@ -19,6 +19,7 @@

from collections import defaultdict
from itertools import repeat
from typing import Any, Dict, Optional, Tuple, Union

import cloudpickle
import numpy as np
@@ -30,9 +31,10 @@

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.model import modelcontext
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
from pymc.stats.convergence import log_warnings, run_convergence_checks
from pymc.util import RandomState, _get_seeds_per_chain


@@ -50,7 +52,7 @@ def sample_smc(
idata_kwargs=None,
progressbar=True,
**kernel_kwargs,
):
) -> Union[InferenceData, MultiTrace]:
r"""
Sequential Monte Carlo based sampling.

@@ -236,20 +238,28 @@
)

if compute_convergence_checks:
_compute_convergence_checks(idata, draws, model, trace)
return idata if return_inferencedata else trace
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
warns = run_convergence_checks(idata, model)
trace.report._add_warnings(warns)
log_warnings(warns)
Comment on lines +241 to +245 (Member Author):

This replaces the _compute_convergence_checks function and makes trace.report a dead end that can easily be removed in the future.

Remember from the other changes:

  • The "number of samples is too small" warning is now emitted by run_convergence_checks.
  • report._add_warnings was previously called inside report._run_convergence_checks.
  • trace.report._log_summary() internally called log_warnings().


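The same three steps now appear in both pm.sample and pm.sample_smc; a condensed sketch of the decoupled flow, using the names from this diff:

```python
idata = to_inference_data(trace, log_likelihood=False)  # build InferenceData if needed
warns = run_convergence_checks(idata, model)            # pure function -> List[SamplerWarning]
trace.report._add_warnings(warns)                       # the report is now a passive container
log_warnings(warns)                                     # emit through the "pymc" logger
```
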
if return_inferencedata:
assert idata is not None
return idata
return trace


def _save_sample_stats(
sample_settings,
sample_stats,
chains,
trace,
return_inferencedata,
trace: MultiTrace,
return_inferencedata: bool,
_t_sampling,
idata_kwargs,
model,
):
model: Model,
) -> Tuple[Optional[Any], Optional[InferenceData]]:
sample_settings_dict = sample_settings[0]
sample_settings_dict["_t_sampling"] = _t_sampling
sample_stats_dict = sample_stats[0]
Expand All @@ -262,12 +272,12 @@ def _save_sample_stats(
value_list.append(chain_sample_stats[stat])
sample_stats_dict[stat] = value_list

idata: Optional[InferenceData] = None
if not return_inferencedata:
for stat, value in sample_stats_dict.items():
setattr(trace.report, stat, value)
for stat, value in sample_settings_dict.items():
setattr(trace.report, stat, value)
idata = None
else:
for stat, value in sample_stats_dict.items():
if chains > 1:
Expand All @@ -284,7 +294,7 @@ def _save_sample_stats(
library=pymc,
)

ikwargs = dict(model=model)
ikwargs: Dict[str, Any] = dict(model=model)
if idata_kwargs is not None:
ikwargs.update(idata_kwargs)
idata = to_inference_data(trace, **ikwargs)
@@ -293,19 +303,6 @@
return sample_stats, idata


def _compute_convergence_checks(idata, draws, model, trace):
if draws < 100:
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
else:
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
trace.report._run_convergence_checks(idata, model)
trace.report._log_summary()


def _sample_smc_int(
draws,
kernel,
9 changes: 6 additions & 3 deletions pymc/stats/convergence.py
@@ -53,10 +53,13 @@ def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
return [warn]

if idata["posterior"].sizes["draw"] < 100:
msg = "The number of samples is too small to check convergence reliably."
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
return [warn]

if idata["posterior"].sizes["chain"] == 1:
msg = (
"Only one chain was sampled, this makes it impossible to " "run some convergence checks"
)
msg = "Only one chain was sampled, this makes it impossible to run some convergence checks"
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
return [warn]

8 changes: 4 additions & 4 deletions pymc/step_methods/compound.py
@@ -20,7 +20,7 @@

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Dict, List, Sequence, Tuple, Union
from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union

import numpy as np

@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
class StatsBijection:
"""Map between a `list` of stats to `dict` of stats."""

def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
Member Author:

Typing rule of thumb: Generic input types, exact output types.
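
As a toy illustration of that rule (a hypothetical helper, not the actual StatsBijection code): accepting Mapping lets callers pass any read-only mapping, while returning Dict promises them a concrete type.

```python
from types import MappingProxyType
from typing import Any, Dict, Mapping

def prefix_stats(stats: Mapping[str, Any]) -> Dict[str, Any]:
    # Generic in: any mapping works. Exact out: always a plain dict.
    return {f"sampler_0__{name}": value for name, value in stats.items()}

# Both a plain dict and an immutable view type-check as arguments:
prefix_stats({"depth": 3})
prefix_stats(MappingProxyType({"depth": 3}))
```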

# Keep a list of flat vs. original stat names
self._stat_groups: List[List[Tuple[str, str]]] = [
[(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()]
for s, names_dtypes in enumerate(sampler_stats_dtypes)
]

def map(self, stats_list: StatsType) -> StatsDict:
def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
"""Combine stats dicts of multiple samplers into one dict."""
stats_dict = {}
for s, sts in enumerate(stats_list):
Expand All @@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict:
stats_dict[sname] = sval
return stats_dict

def rmap(self, stats_dict: StatsDict) -> StatsType:
def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
"""Split a global stats dict into a list of sampler-wise stats dicts."""
stats_list = []
for namemap in self._stat_groups:
11 changes: 5 additions & 6 deletions pymc/tests/smc/test_smc.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings

import numpy as np
@@ -215,13 +216,11 @@ def test_return_datatype(self, chains):
assert mt.nchains == chains
assert mt["x"].size == chains * draws

def test_convergence_checks(self):
with self.fast_model:
with pytest.warns(
UserWarning,
match="The number of samples is too small",
):
def test_convergence_checks(self, caplog):
with caplog.at_level(logging.INFO):
with self.fast_model:
pm.sample_smc(draws=99)
assert "The number of samples is too small" in caplog.text

def test_deprecated_parallel_arg(self):
with self.fast_model: