Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make forward sampling functions return InferenceData #5073

Merged
merged 22 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`. Furthermore `initval` no longer assigns a `tag.test_value` on tensors since the initial values are now kept track of by the model object ([see #4913](https://github.com/pymc-devs/pymc/pull/4913)).
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc/pull/4744)).
- `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object
by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)).
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)).
- ⚠ `pm.Bound` interface no longer accepts a callable class as argument, instead it requires an instantiated distribution (created via the `.dist()` API) to be passed as an argument. In addition, Bound no longer returns a class instance but works as a normal PyMC distribution. Finally, it is no longer possible to do predictive random sampling from Bounded variables. Please, consult the new documentation for details on how to use Bounded variables (see [4815](https://github.com/pymc-devs/pymc/pull/4815)).
- `pm.DensityDist` no longer accepts the `logp` as its first position argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
Expand Down
78 changes: 57 additions & 21 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,11 @@ def sample(
callback=None,
jitter_max_retries=10,
*,
return_inferencedata=None,
return_inferencedata=True,
idata_kwargs: dict = None,
mp_ctx=None,
**kwargs,
):
) -> Union[InferenceData, MultiTrace]:
r"""Draw samples from the posterior using the given step methods.

Multiple step methods are supported via compound step methods.
Expand Down Expand Up @@ -336,9 +336,9 @@ def sample(
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
return_inferencedata : bool, default=True
return_inferencedata : bool
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `False`, but we'll switch to `True` in an upcoming release.
Defaults to `True`.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`
mp_ctx : multiprocessing.context.BaseContent
Expand Down Expand Up @@ -450,9 +450,6 @@ def sample(
if not isinstance(random_seed, abc.Iterable):
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

if return_inferencedata is None:
return_inferencedata = True

if not discard_tuned_samples and not return_inferencedata:
warnings.warn(
"Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
Expand Down Expand Up @@ -1535,7 +1532,9 @@ def sample_posterior_predictive(
random_seed=None,
progressbar: bool = True,
mode: Optional[Union[str, Mode]] = None,
) -> Dict[str, np.ndarray]:
return_inferencedata=True,
idata_kwargs: dict = None,
) -> Union[InferenceData, Dict[str, np.ndarray]]:
"""Generate posterior predictive samples from a model given a trace.

Parameters
Expand Down Expand Up @@ -1570,12 +1569,17 @@ def sample_posterior_predictive(
time until completion ("expected time of arrival"; ETA).
mode:
The mode used by ``aesara.function`` to compile the graph.
return_inferencedata : bool
Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
Defaults to True.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`

Returns
-------
samples : dict
Dictionary with the variable names as keys, and values numpy arrays containing
posterior predictive samples.
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the posterior predictive samples (default), or
a dictionary with variable names as keys, and samples as numpy arrays.
"""

_trace: Union[MultiTrace, PointList]
Expand Down Expand Up @@ -1724,7 +1728,12 @@ def sample_posterior_predictive(
for k, ary in ppc_trace.items():
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))

return ppc_trace
if not return_inferencedata:
return ppc_trace
ikwargs = dict(model=model)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)


def sample_posterior_predictive_w(
Expand All @@ -1734,6 +1743,8 @@ def sample_posterior_predictive_w(
weights: Optional[ArrayLike] = None,
random_seed: Optional[int] = None,
progressbar: bool = True,
return_inferencedata=True,
idata_kwargs: dict = None,
):
"""Generate weighted posterior predictive samples from a list of models and
a list of traces according to a set of weights.
Expand All @@ -1760,12 +1771,18 @@ def sample_posterior_predictive_w(
Whether or not to display a progress bar in the command line. The bar shows the percentage
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
time until completion ("expected time of arrival"; ETA).
return_inferencedata : bool
Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
Defaults to True.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`

Returns
-------
samples : dict
Dictionary with the variables as keys. The values corresponding to the
posterior predictive samples from the weighted models.
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the posterior predictive samples from the
weighted models (default), or a dictionary with variable names as keys, and samples as
numpy arrays.
"""
if isinstance(traces[0], InferenceData):
n_samples = [
Expand Down Expand Up @@ -1884,7 +1901,13 @@ def sample_posterior_predictive_w(
except KeyboardInterrupt:
pass
else:
return {k: np.asarray(v) for k, v in ppc.items()}
ppc = {k: np.asarray(v) for k, v in ppc.items()}
if not return_inferencedata:
return ppc
ikwargs = dict(model=models)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(posterior_predictive=ppc, **ikwargs)


def sample_prior_predictive(
Expand All @@ -1893,7 +1916,9 @@ def sample_prior_predictive(
var_names: Optional[Iterable[str]] = None,
random_seed=None,
mode: Optional[Union[str, Mode]] = None,
) -> Dict[str, np.ndarray]:
return_inferencedata=True,
idata_kwargs: dict = None,
) -> Union[InferenceData, Dict[str, np.ndarray]]:
"""Generate samples from the prior predictive distribution.

Parameters
Expand All @@ -1909,12 +1934,17 @@ def sample_prior_predictive(
Seed for the random number generator.
mode:
The mode used by ``aesara.function`` to compile the graph.
return_inferencedata : bool
Whether to return an :class:`arviz:arviz.InferenceData` (True) object or a dictionary (False).
Defaults to True.
idata_kwargs : dict, optional
Keyword arguments for :func:`pymc.to_inference_data`

Returns
-------
dict
Dictionary with variable names as keys. The values are numpy arrays of prior
samples.
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
or a dictionary with variable names as keys and samples as numpy arrays.
"""
model = modelcontext(model)

Expand Down Expand Up @@ -1980,7 +2010,13 @@ def sample_prior_predictive(
for var_name in vars_:
if var_name in data:
prior[var_name] = data[var_name]
return prior

if not return_inferencedata:
return prior
ikwargs = dict(model=model)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(prior=prior, **ikwargs)


def _init_jitter(model, point, chains, jitter_max_retries):
Expand Down
1 change: 1 addition & 0 deletions pymc/smc/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def initialize_population(self) -> Dict[str, NDArray]:
self.draws,
var_names=[v.name for v in self.model.unobserved_value_vars],
model=self.model,
return_inferencedata=False,
)

def _initialize_kernel(self):
Expand Down
40 changes: 23 additions & 17 deletions pymc/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ def test_sample(self):
prior_trace1 = pm.sample_prior_predictive(1000)
pp_trace1 = pm.sample_posterior_predictive(idata, samples=1000)

assert prior_trace0["b"].shape == (1000,)
assert prior_trace0["obs"].shape == (1000, 100)
assert prior_trace1["obs"].shape == (1000, 200)
assert prior_trace0.prior["b"].shape == (1, 1000)
assert prior_trace0.prior_predictive["obs"].shape == (1, 1000, 100)
assert prior_trace1.prior_predictive["obs"].shape == (1, 1000, 200)

assert pp_trace0["obs"].shape == (1000, 100)

np.testing.assert_allclose(x, pp_trace0["obs"].mean(axis=0), atol=1e-1)

assert pp_trace1["obs"].shape == (1000, 200)
assert pp_trace0.posterior_predictive["obs"].shape == (1, 1000, 100)
np.testing.assert_allclose(
x, pp_trace0.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)

np.testing.assert_allclose(x_pred, pp_trace1["obs"].mean(axis=0), atol=1e-1)
assert pp_trace1.posterior_predictive["obs"].shape == (1, 1000, 200)
np.testing.assert_allclose(
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)

def test_sample_posterior_predictive_after_set_data(self):
with pm.Model() as model:
Expand All @@ -86,8 +88,10 @@ def test_sample_posterior_predictive_after_set_data(self):
pm.set_data(new_data={"x": x_test})
y_test = pm.sample_posterior_predictive(trace)

assert y_test["obs"].shape == (1000, 3)
np.testing.assert_allclose(x_test, y_test["obs"].mean(axis=0), atol=1e-1)
assert y_test.posterior_predictive["obs"].shape == (1, 1000, 3)
np.testing.assert_allclose(
x_test, y_test.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)

def test_sample_after_set_data(self):
with pm.Model() as model:
Expand Down Expand Up @@ -116,8 +120,10 @@ def test_sample_after_set_data(self):
)
pp_trace = pm.sample_posterior_predictive(new_idata, 1000)

assert pp_trace["obs"].shape == (1000, 3)
np.testing.assert_allclose(new_y, pp_trace["obs"].mean(axis=0), atol=1e-1)
assert pp_trace.posterior_predictive["obs"].shape == (1, 1000, 3)
np.testing.assert_allclose(
new_y, pp_trace.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)

def test_shared_data_as_index(self):
"""
Expand All @@ -130,7 +136,7 @@ def test_shared_data_as_index(self):
alpha = pm.Normal("alpha", 0, 1.5, size=3)
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)

prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
prior_trace = pm.sample_prior_predictive(1000)
idata = pm.sample(
1000,
init=None,
Expand All @@ -146,10 +152,10 @@ def test_shared_data_as_index(self):
pm.set_data(new_data={"index": new_index, "y": new_y})
pp_trace = pm.sample_posterior_predictive(idata, 1000, var_names=["alpha", "obs"])

assert prior_trace["alpha"].shape == (1000, 3)
assert prior_trace.prior["alpha"].shape == (1, 1000, 3)
assert idata.posterior["alpha"].shape == (1, 1000, 3)
assert pp_trace["alpha"].shape == (1000, 3)
assert pp_trace["obs"].shape == (1000, 3)
assert pp_trace.posterior_predictive["alpha"].shape == (1, 1000, 3)
assert pp_trace.posterior_predictive["obs"].shape == (1, 1000, 3)

def test_shared_data_as_rv_input(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,7 +3249,7 @@ def test_distinct_rvs():
X_rv = pm.Normal("x")
Y_rv = pm.Normal("y")

pp_samples = pm.sample_prior_predictive(samples=2)
pp_samples = pm.sample_prior_predictive(samples=2, return_inferencedata=False)

assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]

Expand All @@ -3259,7 +3259,7 @@ def test_distinct_rvs():
X_rv = pm.Normal("x")
Y_rv = pm.Normal("y")

pp_samples_2 = pm.sample_prior_predictive(samples=2)
pp_samples_2 = pm.sample_prior_predictive(samples=2, return_inferencedata=False)

assert np.array_equal(pp_samples["y"], pp_samples_2["y"])

Expand Down
6 changes: 3 additions & 3 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def ref_rand(mu, rowcov, colcov):
rowcov=np.eye(3),
colcov=np.eye(3),
)
check = pm.sample_prior_predictive(n_fails)
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False)

ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))

Expand Down Expand Up @@ -1921,7 +1921,7 @@ def sample_prior(self, distribution, shape, nested_rvs_info, prior_samples):
nested_rvs_info,
)
with model:
return pm.sample_prior_predictive(prior_samples)
return pm.sample_prior_predictive(prior_samples, return_inferencedata=False)

@pytest.mark.parametrize(
["prior_samples", "shape", "mu", "alpha"],
Expand Down Expand Up @@ -2379,7 +2379,7 @@ def test_car_rng_fn(sparse):
with pm.Model(rng_seeder=1):
car = pm.CAR("car", mu, W, alpha, tau, size=size)
mn = pm.MvNormal("mn", mu, cov, size=size)
check = pm.sample_prior_predictive(n_fails)
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False)

p, f = delta, n_fails
while p <= delta and f > 0:
Expand Down
Loading