Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pymc3 Plot & Diagnostics & Arviz Dependency #4397

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import timeit

import arviz as az
import numpy as np
import pandas as pd
import theano
Expand Down Expand Up @@ -192,7 +192,7 @@ def track_glm_hierarchical_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
return ess / tot

def track_marginal_mixture_model_ess(self, init):
Expand All @@ -214,7 +214,7 @@ def track_marginal_mixture_model_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = pm.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
ess = az.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
return ess / tot


Expand Down Expand Up @@ -245,7 +245,7 @@ def track_glm_hierarchical_ess(self, step):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
return ess / tot


Expand Down Expand Up @@ -304,7 +304,7 @@ def freefall(y, t, p):
t0 = time.time()
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
tot = time.time() - t0
ess = pm.ess(trace)
ess = az.ess(trace)
return np.mean([ess.sigma, ess.gamma]) / tot


Expand Down
1 change: 0 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ API Reference
api/shape_utils
api/ode


Indices and tables
===================

Expand Down
15 changes: 3 additions & 12 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@ Plots are delegated to the
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
library, a general purpose library for
"exploratory analysis of Bayesian models."
For plots, ``pymc3.<function>`` are now aliases
for ArviZ functions. Thus, the links below will redirect you to
ArviZ docs:
Refer to its documentation to use the plotting functions directly.

- :func:`pymc3.traceplot <arviz:arviz.plot_trace>`
- :func:`pymc3.plot_posterior <arviz:arviz.plot_posterior>`
- :func:`pymc3.forestplot <arviz:arviz.plot_forest>`
- :func:`pymc3.compareplot <arviz:arviz.plot_compare>`
- :func:`pymc3.autocorrplot <arviz:arviz.plot_autocorr>`
- :func:`pymc3.energyplot <arviz:arviz.plot_energy>`
- :func:`pymc3.kdeplot <arviz:arviz.plot_kde>`
- :func:`pymc3.densityplot <arviz:arviz.plot_density>`
- :func:`pymc3.pairplot <arviz:arviz.plot_pair>`
.. automodule:: pymc3.plots.posteriorplot
:members:
19 changes: 1 addition & 18 deletions docs/source/api/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,4 @@ Statistics and diagnostics are delegated to the
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
library, a general purpose library for
"exploratory analysis of Bayesian models."
For statistics and diagnostics, ``pymc3.<function>`` are now aliases
for ArviZ functions. Thus, the links below will redirect you to
ArviZ docs:

.. currentmodule:: pymc3.stats


- :func:`pymc3.bfmi <arviz:arviz.bfmi>`
- :func:`pymc3.compare <arviz:arviz.compare>`
- :func:`pymc3.ess <arviz:arviz.ess>`
- :data:`pymc3.geweke <arviz:arviz.geweke>`
- :func:`pymc3.hpd <arviz:arviz.hpd>`
- :func:`pymc3.loo <arviz:arviz.loo>`
- :func:`pymc3.mcse <arviz:arviz.mcse>`
- :func:`pymc3.r2_score <arviz:arviz.r2_score>`
- :func:`pymc3.rhat <arviz:arviz.rhat>`
- :func:`pymc3.summary <arviz:arviz.summary>`
- :func:`pymc3.waic <arviz:arviz.waic>`
Refer to its documentation to use the diagnostics functions directly.
1 change: 0 additions & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __set_compiler_flags():
from pymc3.plots import *
from pymc3.sampling import *
from pymc3.smc import *
from pymc3.stats import *
from pymc3.step_methods import *
from pymc3.tests import test
from pymc3.theanof import *
Expand Down
101 changes: 5 additions & 96 deletions pymc3/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,108 +14,17 @@

"""PyMC3 Plotting.

Plots are delegated to the ArviZ library, a general purpose library for
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/
for details on plots.
Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for
exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/.

Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots.
"""
import functools
import sys
import warnings

import arviz as az


def map_args(func):
swaps = [("varnames", "var_names")]

@functools.wraps(func)
def wrapped(*args, **kwargs):
for (old, new) in swaps:
if old in kwargs and new not in kwargs:
warnings.warn(
f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8"
)
kwargs[new] = kwargs.pop(old)
return func(*args, **kwargs)

return wrapped


# pymc3 custom plots: override these names for custom behavior
autocorrplot = map_args(az.plot_autocorr)
forestplot = map_args(az.plot_forest)
kdeplot = map_args(az.plot_kde)
plot_posterior = map_args(az.plot_posterior)
energyplot = map_args(az.plot_energy)
densityplot = map_args(az.plot_density)
pairplot = map_args(az.plot_pair)

# Use compact traceplot by default
@map_args
@functools.wraps(az.plot_trace)
def traceplot(*args, **kwargs):
try:
kwargs.setdefault("compact", True)
return az.plot_trace(*args, **kwargs)
except TypeError:
kwargs.pop("compact")
return az.plot_trace(*args, **kwargs)


# addition arg mapping for compare plot
@functools.wraps(az.plot_compare)
def compareplot(*args, **kwargs):
if "comp_df" in kwargs:
comp_df = kwargs["comp_df"].copy()
else:
args = list(args)
comp_df = args[0].copy()
if "WAIC" in comp_df.columns:
comp_df = comp_df.rename(
index=str,
columns={
"WAIC": "waic",
"pWAIC": "p_waic",
"dWAIC": "d_waic",
"SE": "se",
"dSE": "dse",
"var_warn": "warning",
},
)
elif "LOO" in comp_df.columns:
comp_df = comp_df.rename(
index=str,
columns={
"LOO": "loo",
"pLOO": "p_loo",
"dLOO": "d_loo",
"SE": "se",
"dSE": "dse",
"shape_warn": "warning",
},
)
if "comp_df" in kwargs:
kwargs["comp_df"] = comp_df
else:
args[0] = comp_df
return az.plot_compare(*args, **kwargs)


from pymc3.plots.posteriorplot import plot_posterior_predictive_glm

# Access to arviz plots: base plots provided by arviz
for plot in az.plots.__all__:
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))

__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
"forestplot",
"kdeplot",
"plot_posterior",
"traceplot",
"energyplot",
"densityplot",
"pairplot",
"plot_posterior_predictive_glm",
)
__all__ = ["plot_posterior_predictive_glm"]
41 changes: 28 additions & 13 deletions pymc3/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import warnings

from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import matplotlib.pyplot as plt
Expand All @@ -33,20 +35,33 @@ def plot_posterior_predictive_glm(
**kwargs: Any
) -> None:
"""Plot posterior predictive of a linear model.
:Arguments:
trace: InferenceData or MultiTrace
Output of pm.sample()
eval: <array>
Array over which to evaluate lm
lm: function <default: linear function>
Function mapping parameters at different points
to their respective outputs.
input: point, sample
output: estimated value
samples: int <default=30>
How many posterior samples to draw.
Additional keyword arguments are passed to pylab.plot().

Parameters
CloudChaoszero marked this conversation as resolved.
Show resolved Hide resolved
----------
trace: InferenceData or MultiTrace
Output of pm.sample()
eval: <array>
Array over which to evaluate lm
lm: function <default: linear function>
Function mapping parameters at different points
to their respective outputs.
input: point, sample
output: estimated value
samples: int <default=30>
How many posterior samples to draw.
kwargs : mapping, optional
Additional keyword arguments are passed to ``matplotlib.pyplot.plot()``.

Warnings
--------
The `plot_posterior_predictive_glm` function will be removed in a future PyMC3 release.
"""
warnings.warn(
"The `plot_posterior_predictive_glm` function will migrate to Arviz in a future release. "
"\nKeep up to date with `ArviZ <https://arviz-devs.github.io/arviz/>`_ for future updates.",
DeprecationWarning,
)

if lm is None:
lm = lambda x, sample: sample["Intercept"] + sample["x"] * x

Expand Down
2 changes: 1 addition & 1 deletion pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def sample(
...: y = pm.Binomial("y", n=n, p=p, observed=h)
...: trace = pm.sample()

In [3]: pm.summary(trace, kind="stats")
In [3]: az.summary(trace, kind="stats")

Out[3]:
mean sd hdi_3% hdi_97%
Expand Down
69 changes: 0 additions & 69 deletions pymc3/stats/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion pymc3/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class MLDA(ArrayStepShared):
... tune=100, step=step_method,
... random_seed=123)
...
... pm.summary(trace, kind="stats")
... az.summary(trace, kind="stats")
mean sd hdi_3% hdi_97%
x 0.99 0.987 -0.734 2.992

Expand Down
5 changes: 3 additions & 2 deletions pymc3/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import arviz as az
import numpy as np
import numpy.testing as npt
import theano.tensor as tt
Expand Down Expand Up @@ -146,12 +147,12 @@ def setup_class(cls):

def test_neff(self):
if hasattr(self, "min_n_eff"):
n_eff = pm.ess(self.trace[self.burn :])
n_eff = az.ess(self.trace[self.burn :])
for var in n_eff:
npt.assert_array_less(self.min_n_eff, n_eff[var])

def test_Rhat(self):
rhat = pm.rhat(self.trace[self.burn :])
rhat = az.rhat(self.trace[self.burn :])
for var in rhat:
npt.assert_allclose(rhat[var], 1, rtol=0.01)

Expand Down
Loading