Number of samples is high when init='nuts' #3861

Closed
rosgori opened this issue Mar 30, 2020 · 5 comments

rosgori commented Mar 30, 2020

Description of your problem

Please provide a minimal, self-contained, and reproducible example.

When you pass init='nuts' to pm.sample(), the progress bar reports far more samples than requested.

import pymc3 as pm
import numpy as np

N = 1000
x = np.random.normal(loc=0, scale=2, size=N)

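with pm.Model() as model:
    # Priors on the mean and standard deviation
    mu = pm.Normal('mu', mu=0, sigma=5)
    sigma = pm.HalfNormal('sigma', sigma=5)

    # Likelihood of the observed data
    obs = pm.Normal('obs', mu=mu, sigma=sigma, observed=x)

with model:
    # init='nuts' is the option that triggers the long run
    trace = pm.sample(draws=100, tune=100, init='nuts')

print(pm.summary(trace))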

Please provide the full traceback.

Sampling 4 chains, 0 divergences:   1%|| 15137/1200000 [00:05<07:25, 2661.65draws/s]

As you can see, the number 1200000 always appears, no matter how you change draws or tune in pm.sample.

Should that number change when you change draws and tune?

Versions and main components

  • PyMC3 Version: 3.8
  • Theano Version: 1.0.4
  • Python Version: 3.7.6
  • Operating system: Ubuntu 16.04.6 LTS
  • How did you install PyMC3: (conda/pip) conda

twiecki commented Mar 30, 2020

Thanks, that init method should just be dropped. Where did you find it referenced?

AlexAndorra commented

It seems to be in the pm.sample docstring:

nuts: Run NUTS and estimate posterior mean and mass matrix from the trace.

Then, in the source code, when the init method is "nuts", it samples with n_init as the number of draws:

elif init == "nuts":
    init_trace = pm.sample(
        draws=n_init, step=pm.NUTS(), tune=n_init // 2, random_seed=random_seed
    )

Since pm.sample defaults to n_init=200_000, this yields 200_000 + 100_000 samples per chain, which is, I think, where the 1_200_000 comes from (since you're sampling 4 chains in the example).
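
As a quick sanity check of that arithmetic, here is a minimal sketch in plain Python (no PyMC3 needed). The helper name progress_bar_total is hypothetical, and it assumes the progress bar counts tuning steps plus draws, summed over all chains:

```python
def progress_bar_total(chains: int, draws: int, tune: int) -> int:
    """Iterations a progress bar would show, assuming it counts
    (tune + draws) for every chain. Hypothetical helper, not PyMC3 API."""
    return chains * (draws + tune)


N_INIT = 200_000  # default n_init quoted above
CHAINS = 4        # number of chains in the reported progress bar

# What the user asked for: draws=100, tune=100
requested = progress_bar_total(CHAINS, draws=100, tune=100)

# What the init == "nuts" branch runs: draws=n_init, tune=n_init // 2
init_run = progress_bar_total(CHAINS, draws=N_INIT, tune=N_INIT // 2)

print(requested)
print(init_run)
```

Under these assumptions the initialization run alone accounts for the 1200000 shown in the progress bar, independent of the draws and tune values passed by the user.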


twiecki commented Mar 30, 2020

Yeah, this is a leftover; we have better init methods now, so we should just remove it.

AlexAndorra commented

Ok, seems straightforward -- I'll do it


Ahanmr commented Apr 1, 2020

@twiecki @AlexAndorra This issue is fixed, right? Is there still anything pending on this one? It looks like init='nuts' has mostly been dropped already.

twiecki closed this as completed Apr 1, 2020