Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LKJCorr and LKJCholeskyCov refactor #5382

Merged
merged 10 commits into from
Jan 28, 2022

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 24, 2022

This supersedes #4784

Tests are currently failing due to aesara-devs/aesara#786

With local patch, they pass. Also now the random method of LKJCorr works properly for arbitrary sizes, even though the logp method is restricted to 2D values due to reliance on matrix_pos_def

TODO:

Closes #4686

@ricardoV94 ricardoV94 requested a review from kc611 January 24, 2022 13:01
@ricardoV94 ricardoV94 changed the title Lkj corr refactor LKJcorr refactor Jan 24, 2022
@codecov
Copy link

codecov bot commented Jan 24, 2022

Codecov Report

Merging #5382 (8666bbc) into main (ba83d28) will increase coverage by 0.95%.
The diff coverage is 94.08%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5382      +/-   ##
==========================================
+ Coverage   80.43%   81.39%   +0.95%     
==========================================
  Files          82       82              
  Lines       14159    14213      +54     
==========================================
+ Hits        11389    11568     +179     
+ Misses       2770     2645     -125     
Impacted Files Coverage Δ
pymc/distributions/multivariate.py 91.52% <93.92%> (+15.51%) ⬆️
pymc/distributions/transforms.py 100.00% <100.00%> (+7.44%) ⬆️

@ricardoV94
Copy link
Member Author

CC @tomicapretto @danhphan

beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
z = stats.norm.rvs(loc=0, scale=1, size=(size, mp1), random_state=rng)
z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
Copy link
Member Author

@ricardoV94 ricardoV94 Jan 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomicapretto, continuing from #4784 (comment) here is where I think the not yet refactored LKJCholeskyCov random method was wrong. My understanding is that it was trying to generalize the pre-existing code in LJKCorr (removed in this PR) to allow for more flexible sizes, but in doing so altered the meaning of this einsum. The still "buggy" code there for reference is this:

z = z / np.sqrt(np.einsum("ij,ij->j", z, z))

CC @lucianopaz, I think you wrote this code (for the cholesky) originally in #3293, do you have a chance to ping in?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember that PR being centered around making a mixture distribution of MvNormals work, and to be able to sample from their prior. The flexible size stuff came from over there. I hope that I did not mess up the einsum back then, but I honestly don't remember why I had written "j" instead of "i", and I don't remember the algorithm of the rng at all.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I used this new more flexible logic for the LKJCorrRV, the random test failed, and that seems to be the stricter test for the rng that we have.

Then I started debugging line by line, and the einsum index was what changed the results between the old logic in LKJCorrRV and the more flexible one in LKJCholeskyRV.

@ricardoV94 ricardoV94 changed the title LKJcorr refactor LKJCorr refactor Jan 24, 2022
@ricardoV94 ricardoV94 force-pushed the lkj_corr_refactor branch 2 times, most recently from 2bbec95 to db62693 Compare January 25, 2022 12:07
@ricardoV94 ricardoV94 changed the title LKJCorr refactor LKJCorr and LKJCholeskyCov refactor Jan 25, 2022
@ricardoV94
Copy link
Member Author

LKJCholeskyCov is also refactored!

@ricardoV94 ricardoV94 force-pushed the lkj_corr_refactor branch 5 times, most recently from 841e825 to db0b762 Compare January 25, 2022 18:31
@ricardoV94 ricardoV94 marked this pull request as ready for review January 26, 2022 06:34
@ricardoV94
Copy link
Member Author

Tests are passing!

@ricardoV94 ricardoV94 mentioned this pull request Jan 26, 2022
26 tasks
@ricardoV94 ricardoV94 force-pushed the lkj_corr_refactor branch 3 times, most recently from 75bcb20 to a353584 Compare January 28, 2022 10:26
Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great. I don't know what to say about the einsum though.

pymc/distributions/multivariate.py Outdated Show resolved Hide resolved
pymc/distributions/multivariate.py Show resolved Hide resolved
beta -= 0.5
y = stats.beta.rvs(a=mp1 / 2.0, b=beta, size=size, random_state=rng)
z = stats.norm.rvs(loc=0, scale=1, size=(size, mp1), random_state=rng)
z = z / np.sqrt(np.einsum("ij,ij->i", z, z))[..., np.newaxis]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember that PR being centered around making a mixture distribution of MvNormals work, and to be able to sample from their prior. The flexible size stuff came from over there. I hope that I did not mess up the einsum back then, but I honestly don't remember why I had written "j" instead of "i", and I don't remember the algorithm of the rng at all.

Changes:
* compute_corr now defaults to True
* LKJCholeskyCov now also provides a `.dist` interface
@ricardoV94 ricardoV94 merged commit 0dca647 into pymc-devs:main Jan 28, 2022
@twiecki
Copy link
Member

twiecki commented Jan 30, 2022

🥳

@ricardoV94 ricardoV94 deleted the lkj_corr_refactor branch January 31, 2022 09:23
@martiningram
Copy link
Contributor

This is really amazing work! I wanted to give it a go and thought I'd try the LKJ example notebook. However that didn't seem to work unfortunately! I messed around with things a little bit but couldn't get the sampling to work. Sorry, it's entirely possible that my python environment isn't exactly right -- I did upgrade to the aesara and aeppl versions listed here, but maybe I missed something. In any case, it might be worth trying the example notebook @ricardoV94 !

@ricardoV94
Copy link
Member Author

@martiningram I haven't tried to run the notebook but I see the code has a slight issue, relative to V4. The sd_dist should have the same shape as n (i.e., 2), but it is a scalar in the notebook.

Ill try and run the notebook some time soon

@ricardoV94
Copy link
Member Author

Forgot to ask, what error are you seeing?

@martiningram
Copy link
Contributor

The first is in cell 7, following:

packed_L.tag.test_value.shape

That gives:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_30498/1180155108.py in <module>
----> 1 packed_L.tag.test_value.shape

AttributeError: 'tuple' object has no attribute 'tag'

If I skip that line and the next, running

coords = {"axis": ["y", "z"], "axis_bis": ["y", "z"], "obs_id": np.arange(N)}
with pm.Model(coords=coords) as model:
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
    )
    cov = pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))

seems to work, and the next cell also, but sampling gives:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
    308                 ):
--> 309                     thunk()
    310                     for old_s in old_storage:

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py in rval(p, i, o, n)
    507             def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508                 r = p(n, [x[0] for x in i], o)
    509                 for o in node.outputs:

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py in perform(self, node, inputs, outputs)
    381 
--> 382         smpl_val = self.rng_fn(rng, *(args + [size]))
    383 

~/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py in rng_fn(self, rng, n, eta, D, size)
   1149 
-> 1150         D = D.reshape(flat_size, n)
   1151         C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]

ValueError: cannot reshape array of size 1 into shape (1,2)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_30498/1518530166.py in <module>
      1 with model:
----> 2     trace = pm.sample(
      3         random_seed=RANDOM_SEED,
      4         init="adapt_diag",
      5         return_inferencedata=True,

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    494             # By default, try to use NUTS
    495             _log.info("Auto-assigning NUTS sampler...")
--> 496             initial_points, step = init_nuts(
    497                 init=init,
    498                 chains=chains,

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2318     ]
   2319 
-> 2320     initial_points = _init_jitter(
   2321         model,
   2322         initvals,

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2195 
   2196     if not jitter:
-> 2197         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
   2198 
   2199     initial_points = []

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in <listcomp>(.0)
   2195 
   2196     if not jitter:
-> 2197         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]
   2198 
   2199     initial_points = []

~/projects/pymc3_vs_stan/pymc/pymc/initial_point.py in inner(seed, *args, **kwargs)
    214                     new_rng = np.random.Generator(seed)
    215                 rng.set_value(new_rng, True)
--> 216             values = func(*args, **kwargs)
    217             return dict(zip(varnames, values))
    218 

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    967         try:
    968             outputs = (
--> 969                 self.fn()
    970                 if output_subset is None
    971                 else self.fn(output_subset=output_subset)

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
    311                         old_s[0] = None
    312             except Exception:
--> 313                 raise_with_op(self.fgraph, node, thunk)
    314 
    315 

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    523         # Some exception need extra parameter in inputs. So forget the
    524         # extra long error message in that case.
--> 525     raise exc_value.with_traceback(exc_trace)
    526 
    527 

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/link/vm.py in __call__(self)
    307                     self.thunks, self.nodes, self.post_thunk_clear
    308                 ):
--> 309                     thunk()
    310                     for old_s in old_storage:
    311                         old_s[0] = None

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py in rval(p, i, o, n)
    506             # default arguments are stored in the closure of `rval`
    507             def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
--> 508                 r = p(n, [x[0] for x in i], o)
    509                 for o in node.outputs:
    510                     compute_map[o][0] = True

~/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py in perform(self, node, inputs, outputs)
    380         rng_var_out[0] = rng
    381 
--> 382         smpl_val = self.rng_fn(rng, *(args + [size]))
    383 
    384         if (

~/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py in rng_fn(self, rng, n, eta, D, size)
   1148         C = LKJCorrRV._random_corr_matrix(rng, n, eta, flat_size)
   1149 
-> 1150         D = D.reshape(flat_size, n)
   1151         C *= D[..., :, np.newaxis] * D[..., np.newaxis, :]
   1152 

ValueError: cannot reshape array of size 1 into shape (1,2)
Apply node that caused the error: _lkjcholeskycov_rv{1, (0, 0, 1), floatX, False}(RandomStateSharedVariable(<RandomState(PCG64) at 0x7F3D4D447E40>), TensorConstant{[]}, TensorConstant{11}, TensorConstant{2}, TensorConstant{2.0}, exponential_rv{0, (0,), floatX, False}.out)
Toposort index: 1
Inputs types: [RandomStateType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(int32, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: ['No shapes', (0,), (), (), (), ()]
Inputs strides: ['No strides', (8,), (), (), (), ()]
Inputs values: [RandomState(PCG64) at 0x7F3D4D447E40, array([], dtype=int64), array(11), array(2, dtype=int32), array(2.), array(2.01457564)]
Outputs clients: [[], [Elemwise{second,no_inplace}(chol, TensorConstant{(1,) of 0.0})]]

Backtrace when the node is created (use Aesara flag traceback__limit=N to make it longer):
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1181, in __new__
    return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 266, in __new__
    rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape(
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 165, in _make_rv_and_resize_shape
    rv_out = cls.dist(*args, **kwargs)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1203, in dist
    return super().dist([n, eta, sd_dist], size=size, **kwargs)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/distribution.py", line 353, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/tensor/random/op.py", line 293, in __call__
    res = super().__call__(rng, size, dtype, *args, **kwargs)
  File "/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aesara/graph/op.py", line 283, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/distributions/multivariate.py", line 1134, in make_node
    return super().make_node(rng, size, dtype, n, eta, D)

HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

In case it's helpful, I also modified the notebook somewhat to remove the failing lines and to try to implement your suggestion of giving the sd_dist a shape of 2 here. That looked good first, but sampling fails:

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AB8F8E40>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AB8F8E40>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD540>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD640>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [chol, μ]

4.92% [394/8000 00:00<00:17 Sampling 4 chains, 0 divergences]

/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
  np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
  return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
  np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
  return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
  np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
  return np.multiply(self._var, x, out=out)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:258: RuntimeWarning: divide by zero encountered in true_divide
  np.divide(1, self._stds, out=self._inv_stds)
/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py:237: RuntimeWarning: invalid value encountered in multiply
  return np.multiply(self._var, x, out=out)
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(
/home/martin/miniconda3/envs/dl/lib/python3.8/site-packages/aeppl/joint_logprob.py:161: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: normal_rv{0, (0, 0), floatX, False}(RandomStateSharedVariable(<RandomState(MT19937) at 0x7FC4AA9BD340>), TensorConstant{[]}, TensorConstant{11}, BroadcastTo.0, BroadcastTo.0)
  warnings.warn(

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 125, in run
    self._start_loop()
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 178, in _start_loop
    point, stats = self._compute_point()
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py", line 203, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/arraystep.py", line 286, in step
    return super().step(point)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/arraystep.py", line 208, in step
    step_res = self.astep(q)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/base_hmc.py", line 164, in astep
    self.potential.raise_ok(q0.point_map_info)
  File "/home/martin/projects/pymc3_vs_stan/pymc/pymc/step_methods/hmc/quadpotential.py", line 308, in raise_ok
    raise ValueError("\n".join(errmsg))
ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `chol_cholesky-cov-packed__`.ravel()[[0 1 2]] is zero.
The derivative of RV `μ`.ravel()[[0 1]] is zero.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `chol_cholesky-cov-packed__`.ravel()[[0 1 2]] is zero.
The derivative of RV `μ`.ravel()[[0 1]] is zero.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_30650/2034248556.py in <module>
      1 with model:
----> 2     trace = pm.sample(
      3         random_seed=RANDOM_SEED,
      4         init="adapt_diag",
      5     )

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    566         _print_step_hierarchy(step)
    567         try:
--> 568             trace = _mp_sample(**sample_args, **parallel_args)
    569         except pickle.PickleError:
    570             _log.warning("Could not pickle model, sampling singlethreaded.")

~/projects/pymc3_vs_stan/pymc/pymc/sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
   1483         try:
   1484             with sampler:
-> 1485                 for draw in sampler:
   1486                     trace = traces[draw.chain - chain]
   1487                     if trace.supports_sampler_stats and draw.stats is not None:

~/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py in __iter__(self)
    458 
    459         while self._active:
--> 460             draw = ProcessAdapter.recv_draw(self._active)
    461             proc, is_last, draw, tuning, stats, warns = draw
    462             self._total_draws += 1

~/projects/pymc3_vs_stan/pymc/pymc/parallel_sampling.py in recv_draw(processes, timeout)
    347             else:
    348                 error = RuntimeError("Chain %s failed." % proc.chain)
--> 349             raise error from old_error
    350         elif msg[0] == "writing_done":
    351             proc._readable = True

RuntimeError: Chain 2 failed.

Hope some of this is helpful!

@ricardoV94
Copy link
Member Author

ricardoV94 commented Jan 31, 2022

@martiningram, those tag.test_value have to be replaced by .eval(). The tuple error, is because the default behavior compute_corr changed. However after fixing these calls, and the shape of the sd_dist I still get the failed chains. There is some weird problem with the MvNormal logp (if you call model.point_logps(0) twice, you can different values).

Surprisingly this is not even related to the LKJCholeskyCov, but to μ xD. If you swap that for a constant value the logp is no longer stochastic. It may be related to these recent changes: https://github.com/pymc-devs/pymc/pull/5386/files

Anyway, thanks for bringing it up. There is definitely a bug lurking around

@martiningram
Copy link
Contributor

martiningram commented Jan 31, 2022

Oh weird, thanks for taking such a close look at this! Interesting that it's probably just due to μ...! 😃

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Port remaining distributions to v4
4 participants