Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fastprogress instead of tqdm progressbar #3693

Merged
merged 10 commits into from
Dec 9, 2019
7 changes: 6 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Release Notes

## PyMC3 3.8 (on deck)
## PyMC3 3.9 (On deck)

### New features
- use [fastprogress](https://github.com/fastai/fastprogress) instead of tqdm [#3693](https://github.com/pymc-devs/pymc3/pull/3693)

## PyMC3 3.8 (November 29 2019)

### New features
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
Expand Down
2 changes: 1 addition & 1 deletion pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import errno

import numpy as np
from fastprogress import progress_bar

from . import theanof

Expand Down Expand Up @@ -348,7 +349,6 @@ def __init__(
start_chain_num=0,
progressbar=True,
):
from fastprogress import progress_bar

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)
Expand Down
36 changes: 27 additions & 9 deletions pymc3/variational/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def run_profiling(self, n=1000, score=None, **kwargs):
score = self._maybe_score(score)
fn_kwargs = kwargs.pop("fn_kwargs", dict())
fn_kwargs["profile"] = True
step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs)
step_func = self.objective.step_function(
score=score, fn_kwargs=fn_kwargs, **kwargs
)
progress = progress_bar(range(n))
try:
for _ in progress:
Expand Down Expand Up @@ -555,12 +557,16 @@ def __init__(
random_seed=None,
estimator=KSD,
kernel=test_functions.rbf,
**kwargs
**kwargs,
):
if kwargs.get("local_rv") is not None:
raise opvi.AEVBInferenceError("SVGD does not support local groups")
empirical = Empirical(
size=n_particles, jitter=jitter, start=start, model=model, random_seed=random_seed,
size=n_particles,
jitter=jitter,
start=start,
model=model,
random_seed=random_seed,
)
super().__init__(approx=empirical, estimator=estimator, kernel=kernel, **kwargs)

Expand Down Expand Up @@ -626,14 +632,22 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
)
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)

def fit(self, n=10000, score=None, callbacks=None, progressbar=True, obj_n_mc=500, **kwargs):
def fit(
self,
n=10000,
score=None,
callbacks=None,
progressbar=True,
obj_n_mc=500,
**kwargs,
):
return super().fit(
n=n,
score=score,
callbacks=callbacks,
progressbar=progressbar,
obj_n_mc=obj_n_mc,
**kwargs
**kwargs,
)

def run_profiling(self, n=1000, score=None, obj_n_mc=500, **kwargs):
Expand Down Expand Up @@ -703,7 +717,7 @@ def fit(
random_seed=None,
start=None,
inf_kwargs=None,
**kwargs
**kwargs,
):
r"""Handy shortcut for using inference methods in functional way

Expand Down Expand Up @@ -780,7 +794,9 @@ def fit(
inf_kwargs["start"] = start
if model is None:
model = pm.modelcontext(model)
_select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD, nfvi=NFVI)
_select = dict(
advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD, nfvi=NFVI
)
if isinstance(method, str):
method = method.lower()
if method.startswith("nfvi="):
Expand All @@ -791,10 +807,12 @@ def fit(
inference = _select[method](model=model, **inf_kwargs)
else:
raise KeyError(
"method should be one of %s " "or Inference instance" % set(_select.keys())
f"method should be one of {set(_select.keys())} or Inference instance"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯

)
elif isinstance(method, Inference):
inference = method
else:
raise TypeError("method should be one of %s " "or Inference instance" % set(_select.keys()))
raise TypeError(
f"method should be one of {set(_select.keys())} or Inference instance"
)
return inference.fit(n, **kwargs)