improve ABC sampler (#3940)
* Expand ABC features.

* valueerror

* update notebook

* remove unused import; update release notes

* fix notebook style and change order of params argument
aloctavodia authored Jun 9, 2020
1 parent c6bba80 commit dc574b7
Showing 6 changed files with 246 additions and 190 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -19,6 +19,7 @@
- `pm.Data` container can now be used as input for other random variables (issue [#3842](https://github.com/pymc-devs/pymc3/issues/3842), fixed by [#3925](https://github.com/pymc-devs/pymc3/pull/3925)).
- Plots and Stats API sections now link to ArviZ documentation [#3927](https://github.com/pymc-devs/pymc3/pull/3927)
- Add `SamplerReport` with properties `n_draws`, `t_sampling` and `n_tune` to SMC. `n_tune` is always 0 [#3931](https://github.com/pymc-devs/pymc3/issues/3931).
- SMC-ABC: add option to define summary statistics, allow to sample from more complex models, remove redundant distances [#3940](https://github.com/pymc-devs/pymc3/issues/3940)

### Maintenance
- Tuning results no longer leak into sequentially sampled `Metropolis` chains (see #3733 and #3796).
322 changes: 183 additions & 139 deletions docs/source/notebooks/SMC-ABC_Lotka-Volterra_example.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pymc3/distributions/simulator.py
@@ -19,17 +19,20 @@


class Simulator(NoDistribution):
def __init__(self, function, *args, **kwargs):
def __init__(self, function, *args, params=None, **kwargs):
"""
This class stores a function defined by the user in Python.
function: function
Simulation function defined by the user.
params: list
Parameters passed to function.
*args and **kwargs:
Arguments and keyword arguments that the function takes.
"""

self.function = function
self.params = params
observed = self.data
super().__init__(shape=np.prod(observed.shape), dtype=observed.dtype, *args, **kwargs)

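For context, a minimal sketch of how the new `params` argument is meant to be used, mirroring the updated test at the bottom of this diff (the model and variable names here are illustrative):

```python
import numpy as np
import pymc3 as pm

# Synthetic "observed" data for illustration.
data = np.random.normal(loc=0, scale=1, size=1000)

def normal_sim(a, b):
    # User-defined simulator: generates synthetic data given parameters a and b.
    return np.random.normal(a, b, 1000)

with pm.Model() as example_model:
    a = pm.Normal("a", mu=0, sd=5)
    b = pm.HalfNormal("b", sd=2)
    # The random variables parameterizing the simulator are now passed explicitly
    # via `params` (previously every untransformed unobserved RV was forwarded).
    s = pm.Simulator("s", normal_sim, params=(a, b), observed=data)
```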
13 changes: 6 additions & 7 deletions pymc3/smc/sample_smc.py
@@ -28,8 +28,8 @@ def sample_smc(
p_acc_rate=0.99,
threshold=0.5,
epsilon=1.0,
dist_func="absolute_error",
sum_stat=False,
dist_func="gaussian_kernel",
sum_stat="identity",
progressbar=False,
model=None,
random_seed=-1,
@@ -71,11 +71,10 @@
epsilon: float
Standard deviation of the gaussian pseudo likelihood. Only works with `kernel = ABC`
dist_func: str
Distance function. Available options are ``absolute_error`` (default) and
``sum_of_squared_distance``. Only works with ``kernel = ABC``
sum_stat: bool
Whether to use or not a summary statistics. Defaults to False. Only works with
``kernel = ABC``
Distance function. The only available option is ``gaussian_kernel``
sum_stat: str or callable
Summary statistics. Available options are ``identity``, ``sorted``, ``mean``, ``median``.
If a callable is passed, it should return a number or a 1d numpy array.
progressbar: bool
Flag for displaying a progress bar. Defaults to False.
model: Model (optional if in ``with`` context)).
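A sketch of invoking the updated sampler with the new options (hypothetical usage; `example_model` stands in for any model whose likelihood is a `pm.Simulator`, like the one sketched above):

```python
import pymc3 as pm

with example_model:
    trace = pm.sample_smc(
        draws=1000,
        kernel="ABC",                 # SMC-ABC; dist_func/sum_stat only apply here
        dist_func="gaussian_kernel",  # the only distance currently implemented
        sum_stat="sorted",            # or "identity", "mean", "median", or a callable
        epsilon=1.0,                  # scale of the Gaussian pseudo-likelihood
    )
```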
87 changes: 48 additions & 39 deletions pymc3/smc/smc.py
@@ -31,7 +31,6 @@
from ..step_methods.metropolis import MultivariateNormalProposal
from ..backends.ndarray import NDArray
from ..backends.base import MultiTrace
from ..util import is_transformed_name

EXPERIMENTAL_WARNING = (
"Warning: SMC-ABC methods are experimental step methods and not yet"
@@ -53,7 +52,7 @@ def __init__(
threshold=0.5,
epsilon=1.0,
dist_func="absolute_error",
sum_stat=False,
sum_stat="Identity",
progressbar=False,
model=None,
random_seed=-1,
@@ -140,6 +139,7 @@ def setup_kernel(self):
self.epsilon,
simulator.observations,
simulator.distribution.function,
[v.name for v in simulator.distribution.params],
self.model,
self.var_info,
self.variables,
@@ -281,7 +281,7 @@ def mutate(self):
self.priors[draw],
self.likelihoods[draw],
draw,
*parameters
*parameters,
)
for draw in iterator
]
@@ -307,7 +307,7 @@ def posterior_to_trace(self):
size = 0
for var in varnames:
shape, new_size = self.var_info[var]
value.append(self.posterior[i][size: size + new_size].reshape(shape))
value.append(self.posterior[i][size : size + new_size].reshape(shape))
size += new_size
strace.record({k: v for k, v in zip(varnames, value)})
return MultiTrace([strace])
@@ -389,7 +389,16 @@ class PseudoLikelihood:
"""

def __init__(
self, epsilon, observations, function, model, var_info, variables, distance, sum_stat
self,
epsilon,
observations,
function,
params,
model,
var_info,
variables,
distance,
sum_stat,
):
"""
epsilon: float
@@ -398,34 +407,48 @@ def __init__(
observed data
function: python function
data simulator
params: list
names of the variables parameterizing the simulator.
model: PyMC3 model
var_info: dict
generated by ``SMC.initialize_population``
distance: str
Distance function. Available options are ``absolute_error`` (default) and
``sum_of_squared_distance``.
sum_stat: bool
Whether to use or not a summary statistics.
distance : str or callable
Distance function. The only available option is ``gaussian_kernel``
sum_stat: str or callable
Summary statistics. Available options are ``identity``, ``sorted``, ``mean``,
``median``. The user can pass any valid Python function.
"""
self.epsilon = epsilon
self.observations = observations
self.function = function
self.params = params
self.model = model
self.var_info = var_info
self.variables = variables
self.varnames = [v.name for v in self.variables]
self.unobserved_RVs = [v.name for v in self.model.unobserved_RVs]
self.kernel = self.gauss_kernel
self.dist_func = distance
self.sum_stat = sum_stat
self.get_unobserved_fn = self.model.fastfn(self.model.unobserved_RVs)

if distance == "absolute_error":
self.dist_func = self.absolute_error
elif distance == "sum_of_squared_distance":
self.dist_func = self.sum_of_squared_distance
if sum_stat == "identity":
self.sum_stat = lambda x: x
elif sum_stat == "sorted":
self.sum_stat = np.sort
elif sum_stat == "mean":
self.sum_stat = np.mean
elif sum_stat == "median":
self.sum_stat = np.median
elif hasattr(sum_stat, "__call__"):
self.sum_stat = sum_stat
else:
raise ValueError(f"The summary statistics {sum_stat} is not implemented")

self.observations = self.sum_stat(observations)

if distance == "gaussian_kernel":
self.distance = self.gaussian_kernel
elif hasattr(distance, "__call__"):
self.distance = distance
else:
raise ValueError("Distance metric not understood")
raise ValueError(f"The distance metric {distance} is not implemented")

def posterior_to_function(self, posterior):
model = self.model
@@ -436,32 +459,18 @@ def posterior_to_function(self, posterior):
size = 0
for var in self.variables:
shape, new_size = var_info[var.name]
varvalues.append(posterior[size: size + new_size].reshape(shape))
varvalues.append(posterior[size : size + new_size].reshape(shape))
size += new_size
point = {k: v for k, v in zip(self.varnames, varvalues)}
for varname, value in zip(self.unobserved_RVs, self.get_unobserved_fn(point)):
if not is_transformed_name(varname):
if varname in self.params:
samples[varname] = value
return samples

def gauss_kernel(self, value):
epsilon = self.epsilon
return (-(value ** 2) / epsilon ** 2 + np.log(1 / (2 * np.pi * epsilon ** 2))) / 2.0

def absolute_error(self, a, b):
if self.sum_stat:
return np.abs(a.mean() - b.mean())
else:
return np.mean(np.atleast_2d(np.abs(a - b)))

def sum_of_squared_distance(self, a, b):
if self.sum_stat:
return np.sum(np.atleast_2d((a.mean() - b.mean()) ** 2))
else:
return np.mean(np.sum(np.atleast_2d((a - b) ** 2)))
def gaussian_kernel(self, obs_data, sim_data):
return np.sum(-0.5 * ((obs_data - sim_data) / self.epsilon) ** 2)

def __call__(self, posterior):
func_parameters = self.posterior_to_function(posterior)
sim_data = self.function(**func_parameters)
value = self.dist_func(self.observations, sim_data)
return self.kernel(value)
sim_data = self.sum_stat(self.function(**func_parameters))
return self.distance(self.observations, sim_data)
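Since `sum_stat` now accepts a callable, a hedged sketch of plugging in a custom summary statistic (the quantile-based statistic and `example_model` below are purely illustrative). Note that `gaussian_kernel` returns `np.sum(-0.5 * ((obs_data - sim_data) / self.epsilon) ** 2)`, i.e. an unnormalized Gaussian log-density over the summary statistics:

```python
import numpy as np
import pymc3 as pm

def quantile_summary(x):
    # Custom summary statistic: must return a number or a 1d numpy array.
    return np.quantile(x, [0.1, 0.25, 0.5, 0.75, 0.9])

with example_model:  # a model with a pm.Simulator likelihood, as sketched earlier
    trace = pm.sample_smc(
        draws=1000,
        kernel="ABC",
        sum_stat=quantile_summary,  # callables are dispatched via hasattr(sum_stat, "__call__")
        epsilon=1.0,
    )
```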
8 changes: 4 additions & 4 deletions pymc3/tests/test_smc.py
@@ -98,19 +98,19 @@ def test_start(self):
class TestSMCABC(SeededTest):
def setup_class(self):
super().setup_class()
self.data = np.sort(np.random.normal(loc=0, scale=1, size=1000))
self.data = np.random.normal(loc=0, scale=1, size=1000)

def normal_sim(a, b):
return np.sort(np.random.normal(a, b, 1000))
return np.random.normal(a, b, 1000)

with pm.Model() as self.SMABC_test:
a = pm.Normal("a", mu=0, sd=5)
b = pm.HalfNormal("b", sd=2)
s = pm.Simulator("s", normal_sim, observed=self.data)
s = pm.Simulator("s", normal_sim, params=(a, b), observed=self.data)

def test_one_gaussian(self):
with self.SMABC_test:
trace = pm.sample_smc(draws=2000, kernel="ABC", epsilon=0.1)
trace = pm.sample_smc(draws=1000, kernel="ABC", sum_stat="sorted", epsilon=1)

np.testing.assert_almost_equal(self.data.mean(), trace["a"].mean(), decimal=2)
np.testing.assert_almost_equal(self.data.std(), trace["b"].mean(), decimal=1)
