BUG: SMC does not respect initval #7438

Closed
tvwenger opened this issue Jul 29, 2024 · 1 comment

@tvwenger
Contributor

Describe the issue:

Per the discussion here, SMC does not properly handle a custom `initval` for free RVs. This is apparent when sampling from distributions with the `Ordered()` transformation, where the initial value must be ordered; otherwise the model fails to sample because the logp at the initial point is NaN.
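
To illustrate why an unordered starting value breaks things, here is a small NumPy re-implementation of the log-difference mapping that, as far as I understand, PyMC's `Ordered` transform is based on (`y[0] = x[0]`, `y[i] = log(x[i] - x[i-1])`). This is a sketch, not PyMC's code, and the helper names `ordered_forward` / `ordered_backward` are hypothetical. It only shows that an unordered vector, such as a raw draw from `Normal(0, 1)`, maps to NaN in the transformed space, which then makes the logp NaN.

```python
import numpy as np


def ordered_forward(x):
    """Hypothetical helper: map an ordered vector to unconstrained space.

    y[0] = x[0]; y[i] = log(x[i] - x[i-1]) for i >= 1.
    """
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    y[0] = x[0]
    # For an unordered x some differences are negative, and log() of a
    # negative number is NaN (silenced here instead of a RuntimeWarning).
    with np.errstate(invalid="ignore"):
        y[1:] = np.log(np.diff(x))
    return y


def ordered_backward(y):
    """Hypothetical helper: inverse map, x[0] = y[0]; x[i] = x[i-1] + exp(y[i])."""
    y = np.asarray(y, dtype=float)
    x = np.empty_like(y)
    x[0] = y[0]
    x[1:] = y[0] + np.cumsum(np.exp(y[1:]))
    return x


ordered_init = [-1.0, 0.0, 1.0]   # the initval from the example below
unordered_init = [0.5, -0.3, 1.2]  # e.g. an unconstrained draw from Normal(0, 1)

print(ordered_forward(ordered_init))    # all finite
print(ordered_forward(unordered_init))  # second entry is NaN
print(ordered_backward(ordered_forward(ordered_init)))  # round-trips to the input
```

Because the backward map always produces an increasing vector, any valid transformed value corresponds to an ordered one; only the forward direction can fail, and it fails silently with NaN rather than raising.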

Reproducible code example:

import pymc as pm
from pymc.distributions.transforms import Ordered

with pm.Model() as model:
    a = pm.Normal("a", mu=0.0, sigma=1.0, size=(3,), transform=Ordered(), initval=[-1.0, 0.0, 1.0])
    b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])

with model:
    trace = pm.sample_smc()

Error message:

<details>
---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/process.py", line 256, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py", line 346, in _sample_smc_int
    smc.update_beta_and_weights()
  File "/home/twenger/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/kernels.py", line 273, in update_beta_and_weights
    ESS = int(np.exp(-logsumexp(log_weights * 2)))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: cannot convert float NaN to integer
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[1], line 9
      6     b = pm.Normal("b", mu=a, sigma=1.0, observed=[0.0, 0.0, 0.0])
      8 with model:
----> 9     trace = pm.sample_smc()

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:217, in sample_smc(draws, kernel, start, model, random_seed, chains, cores, compute_convergence_checks, return_inferencedata, idata_kwargs, progressbar, **kernel_kwargs)
    208 params = (
    209     draws,
    210     kernel,
    211     start,
    212     model,
    213 )
    215 t1 = time.time()
--> 217 results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
    219 (
    220     traces,
    221     sample_stats,
    222     sample_settings,
    223 ) = zip(*results)
    225 trace = MultiTrace(traces)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
    415                 # update the progress bar for this task:
    416                 progress.update(
    417                     status=f"Stage: {stage} Beta: {beta:.3f}",
    418                     task_id=task_id,
    419                     refresh=True,
    420                 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/smc/sampling.py:422, in <genexpr>(.0)
    415                 # update the progress bar for this task:
    416                 progress.update(
    417                     status=f"Stage: {stage} Beta: {beta:.3f}",
    418                     task_id=task_id,
    419                     refresh=True,
    420                 )
--> 422 return tuple(cloudpickle.loads(r.result()) for r in done)

File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:449, in Future.result(self, timeout)
    447     raise CancelledError()
    448 elif self._state == FINISHED:
--> 449     return self.__get_result()
    451 self._condition.wait(timeout)
    453 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File ~/miniconda3/envs/pymc/lib/python3.11/concurrent/futures/_base.py:401, in Future.__get_result(self)
    399 if self._exception:
    400     try:
--> 401         raise self._exception
    402     finally:
    403         # Break a reference cycle with the exception in self._exception
    404         self = None

ValueError: cannot convert float NaN to integer
</details>
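
The final `ValueError` is only a downstream symptom of the NaN logp. A minimal sketch of the expression on the failing line of `kernels.py` (`ESS = int(np.exp(-logsumexp(log_weights * 2)))`) shows why: a single NaN log-weight makes the whole log-sum-exp NaN, and `int()` cannot convert NaN. The plain-NumPy `logsumexp` below is an assumed stand-in for `scipy.special.logsumexp`, and the helper names and weights are illustrative only.

```python
import numpy as np


def logsumexp(a):
    """Plain NumPy log-sum-exp; a stand-in for scipy.special.logsumexp."""
    a = np.asarray(a, dtype=float)
    a_max = np.max(a)  # np.max propagates NaN, so one NaN entry poisons the result
    return a_max + np.log(np.sum(np.exp(a - a_max)))


def ess(log_weights):
    # Same expression as the failing line in the traceback above.
    return int(np.exp(-logsumexp(log_weights * 2)))


n_particles = 4
# Normalised log-weights for equally weighted particles.
healthy = np.full(n_particles, -np.log(n_particles))
print(ess(healthy))  # roughly the number of particles

# One particle whose logp is NaN (e.g. an unordered value under Ordered()).
broken = healthy.copy()
broken[0] = np.nan
try:
    ess(broken)
except ValueError as err:
    print(err)  # cannot convert float NaN to integer
```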

PyMC version information:

pymc 5.16.2 (conda)

Context for the issue:

No response

@tvwenger
Contributor Author

tvwenger commented Aug 1, 2024

Closed by #7439

tvwenger closed this as completed Aug 1, 2024