Include observed_data and sample_stats in inferencedata returned from sample_numpyro_nuts #5121

Closed
GMCobraz opened this issue Oct 31, 2021 · 2 comments


@GMCobraz

Dear PyMC developers,

I noticed a difference in the trace output between ordinary sampling and JAX sampling.

With return_inferencedata=True, ordinary sampling (pm.sample) gives me the following groups in the trace:
  • posterior
  • log_likelihood
  • sample_stats
  • observed_data

But with JAX sampling (pm.sampling_jax.sample_numpyro_nuts) I only get posterior, so I cannot run plot_ppc because there is no observed_data group.
How can I resolve this? Is there any setting I can use to get the other groups as well?

thanks
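A possible workaround while the JAX sampler returns only the posterior group is to attach the missing group yourself. The sketch below is a minimal example assuming the ArviZ 0.x API (`az.from_dict`, `InferenceData.extend`); the arrays are synthetic placeholders, not the SP500 data. Note that `az.plot_ppc` compares `observed_data` against a `posterior_predictive` group, so posterior predictive draws (for example from `pm.sample_posterior_predictive`) would also be needed before plotting.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Stand-in for a posterior-only InferenceData, like the one described for
# sample_numpyro_nuts; posterior arrays are shaped (chain, draw).
idata = az.from_dict(
    posterior={"volatility_mu": rng.normal(-5.0, 0.1, size=(2, 100))}
)

# Hypothetical observed values; in the issue these would be returns["change"].values.
observed = rng.normal(0.0, 0.01, size=50)

# from_dict stores observed_data without chain/draw dimensions.
obs_idata = az.from_dict(observed_data={"obs": observed})

# Add the observed_data group to the existing object in place.
idata.extend(obs_idata)

print(idata.groups())
```

`extend` only adds groups that the target object does not already have, so the existing posterior is left untouched.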

Below is my reproducible code:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3.sampling_jax
import pymc3.distributions.timeseries as ts
import numpyro
import theano

# Run JAX/NumPyro on the CPU.
numpyro.util.set_platform('cpu')

print(f"Running on PyMC3 v{pm.__version__}")

# In a Jupyter notebook: %matplotlib inline

# Daily S&P 500 closes; "change" holds the daily log return.
returns = pd.read_csv(pm.get_data("SP500.csv"), index_col="Date")
returns["change"] = np.log(returns["Close"]).diff()
returns = returns.dropna()
n = len(returns)
returns.head()

# Wiggins stochastic-volatility model

def WigginsDGP(volatility_mu, volatility_theta, volatility_sigma):
    """Simulate n returns whose log-volatility mean-reverts towards volatility_mu."""
    wiggins = []
    volatility = [volatility_mu]
    for _ in range(n):
        v = volatility[-1] + volatility_theta * (volatility_mu - volatility[-1]) \
            + np.random.normal(0., volatility_sigma)

        volatility.append(v)
        r = np.random.normal(0., np.exp(v))

        wiggins.append(r)

    return np.array(wiggins)


def sde(x, theta, mu, sigma):
    # Drift and diffusion terms of the log-volatility SDE.
    return theta * (mu - x), sigma


# In a Jupyter notebook this cell was timed with %%time
with pm.Model() as wiggins_model:

    volatility_theta = pm.Uniform('volatility_theta', lower=0., upper=1., testval=0.5)
    volatility_mu = pm.Normal('volatility_mu', mu=-5., sd=.1, testval=-5)
    volatility_sigma = pm.Uniform('volatility_sigma', lower=0.001, upper=0.09, testval=0.05)

    #sde = lambda x, theta, mu, sigma: (theta * (mu - x), sigma)

    volatility = ts.EulerMaruyama('volatility',
                                  1.0,
                                  sde,
                                  (volatility_theta, volatility_mu, volatility_sigma),
                                  shape=len(returns),
                                  testval=np.ones_like(returns["change"]))

    pm.Normal('obs', mu=0., sd=pm.math.exp(volatility), observed=returns["change"].values)

    trace = pm.sample(4000, cores=8, chains=2, tune=3000, random_seed=42, return_inferencedata=True)
    #trace = pm.sampling_jax.sample_numpyro_nuts(4000, chains=2, tune=3000, random_seed=42)
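For reference, the model above reads as an Ornstein–Uhlenbeck-type log-volatility process, discretised with EulerMaruyama at a step of 1.0. As we understand PyMC3's EulerMaruyama, each step is Normal with mean x_t + dt·f and standard deviation sqrt(dt)·g, where (f, g) is what `sde` returns; WigginsDGP simulates the same recursion with Δt = 1. This is a sketch of our reading of the code, not an authoritative statement of the library's internals.

```latex
\begin{aligned}
\theta &\sim \mathrm{Uniform}(0,\ 1), \quad
\mu \sim \mathcal{N}(-5,\ 0.1^2), \quad
\sigma \sim \mathrm{Uniform}(0.001,\ 0.09) \\
dv &= \theta(\mu - v)\,dt + \sigma\,dW \\
v_{t+1} \mid v_t &\sim \mathcal{N}\!\left(v_t + \theta(\mu - v_t)\,\Delta t,\ \sigma^2 \Delta t\right), \qquad \Delta t = 1 \\
r_t \mid v_t &\sim \mathcal{N}\!\left(0,\ e^{2 v_t}\right)
\end{aligned}
```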

Versions and main components

  • PyMC/PyMC3 Version: 3.11.2
  • Aesara/Theano Version: 1.1.2
  • Python Version: 3.7.11
  • Operating system: Ubuntu 20.04
  • How did you install PyMC/PyMC3: conda
@ricardoV94 changed the title from "sample_numpyro_nuts return only posterior" to "Include observed_data and sample_stats in inferencedata returned from sample_numpyro_nuts" on Nov 3, 2021
@ricardoV94
Member

See also #5100

@OriolAbril
Member

I don't know much about the internals of the NumPyro sampling code, but using az.from_numpyro instead of az.from_dict in https://github.com/pymc-devs/pymc/blob/main/pymc/sampling_jax.py#L185 might solve some of these issues.
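To illustrate the conversion side, the sketch below shows that `az.from_dict` itself accepts `sample_stats` and `observed_data` alongside `posterior`, so all three groups could come back from a JAX run if the sampler passes them on. It assumes the ArviZ 0.x API; the variable names, the `diverging`/`energy` diagnostics and all arrays are hypothetical placeholders, not what `sampling_jax.py` actually produces. The `az.from_numpyro` route suggested above would instead read these groups from the NumPyro `MCMC` object, which this sketch does not cover.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(42)
n_chains, n_draws, n_obs = 2, 200, 30

# Placeholder posterior draws; every array is shaped (chain, draw, ...).
posterior = {
    "volatility_mu": rng.normal(-5.0, 0.1, size=(n_chains, n_draws)),
    "volatility_theta": rng.uniform(0.0, 1.0, size=(n_chains, n_draws)),
}

# Hypothetical sampler diagnostics, named after ArviZ's sample_stats conventions.
sample_stats = {
    "diverging": np.zeros((n_chains, n_draws), dtype=bool),
    "energy": rng.normal(size=(n_chains, n_draws)),
}

# Observed data carries no chain/draw dimensions.
observed_data = {"obs": rng.normal(0.0, 0.01, size=n_obs)}

idata = az.from_dict(
    posterior=posterior,
    sample_stats=sample_stats,
    observed_data=observed_data,
)

print(idata.groups())
```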


3 participants