Convert random variables to value variables so pm.sample(var_names) works correctly #7284

Merged
merged 1 commit into from
Apr 28, 2024

Conversation

tomicapretto
Contributor

@tomicapretto tomicapretto commented Apr 27, 2024

Description

This PR converts the random variables selected by var_names into their value variables when var_names is not None in pm.sample(). Before this PR, passing var_names made the trace record draws from the prior. The linked issue explains the problem in more detail.
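
As a rough illustration of why recording random variables instead of value variables yields prior draws, here is a toy sketch in plain Python. It is an analogy under stated assumptions, not PyMC's internals: the model (theta ~ Normal(0, 1), 20 observations equal to 5.0 with unit noise), the Metropolis step, and every name in it (`run_chain`, `value_trace`, `rv_trace`) are hypothetical. The sampler state plays the role of a value variable, while "evaluating the random variable" is modeled as a fresh draw from the prior that ignores the sampler state.

```python
import math
import random
import statistics

# Toy model (hypothetical): theta ~ Normal(0, 1), y_i ~ Normal(theta, 1).
data = [5.0] * 20


def log_posterior(theta):
    """Unnormalized log posterior of theta given the toy data."""
    log_prior = -0.5 * theta**2
    log_lik = sum(-0.5 * (y - theta) ** 2 for y in data)
    return log_prior + log_lik


def run_chain(n_draws, n_tune, seed):
    """Random-walk Metropolis that records two different things per draw."""
    rng = random.Random(seed)
    value = 0.0  # sampler state: the analogue of a value variable
    value_trace, rv_trace = [], []
    for i in range(n_tune + n_draws):
        proposal = value + rng.gauss(0.0, 0.5)
        log_ratio = log_posterior(proposal) - log_posterior(value)
        # Accept with probability min(1, exp(log_ratio)).
        if rng.random() < math.exp(min(0.0, log_ratio)):
            value = proposal
        if i >= n_tune:
            # Recording the sampler state gives posterior draws.
            value_trace.append(value)
            # "Evaluating the random variable" is modeled as a fresh prior
            # draw that never looks at the sampler state.
            rv_trace.append(rng.gauss(0.0, 1.0))
    return value_trace, rv_trace


value_trace, rv_trace = run_chain(n_draws=2000, n_tune=500, seed=1234)
print("mean of recorded values:  ", statistics.mean(value_trace))
print("mean of recorded RV draws:", statistics.mean(rv_trace))
```

In this toy setup the analytic posterior mean is 100/21 (about 4.76), so the recorded values should land near that number, while the recorded "random variable" draws stay centered on the prior mean of 0 no matter what the sampler does.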

Related Issue

Closes #7258

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7284.org.readthedocs.build/en/7284/

@tomicapretto
Contributor Author

@ricardoV94, to match what model.unobserved_value_vars does, I tried the following:

    # Get value variables for the trace
    if var_names is not None:
        value_vars = []
        transformed_rvs = []
        for rv in model.unobserved_RVs:
            if rv.name in var_names:
                value_var = model.rvs_to_values[rv]
                transform = model.rvs_to_transforms[rv]
                if transform is not None:
                    transformed_rvs.append(rv)
                value_vars.append(value_var)

        transformed_value_vars = model.replace_rvs_by_values(transformed_rvs)
        trace_vars = value_vars + transformed_value_vars
        assert len(trace_vars) == len(var_names), "Not all var_names were found in the model"

However, this raises an assertion error whenever a selected variable is transformed, because each transformed variable adds two elements to the trace_vars list (kappa_log__ and kappa in the example shown in the issue).

Do you think the modification already pushed in the PR is the correct one, or do we need to somehow explicitly account for the transformations?
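
The count mismatch described above can be reproduced with a toy stand-in for the model. This is only a sketch: `toy_model`, `collect_trace_names`, and the dictionary layout are hypothetical and work on plain strings rather than PyMC variables. The final name-based check is one possible alternative to the length assertion, not what the PR implements.

```python
# Hypothetical stand-in for a model:
# random variable name -> (value variable name, has a transform?)
toy_model = {
    "b_batch": ("b_batch", False),
    "b_temp": ("b_temp", False),
    "kappa": ("kappa_log__", True),
}


def collect_trace_names(model, var_names):
    """Mimic the snippet above using names instead of graph variables."""
    value_vars = []
    transformed_rvs = []
    for name, (value_name, has_transform) in model.items():
        if name in var_names:
            value_vars.append(value_name)
            if has_transform:
                # A transformed variable is also traced on its original scale.
                transformed_rvs.append(name)
    return value_vars + transformed_rvs


var_names = ["b_batch", "b_temp", "kappa"]
trace_names = collect_trace_names(toy_model, var_names)
print(trace_names)  # ['b_batch', 'b_temp', 'kappa_log__', 'kappa']
print(len(trace_names) == len(var_names))  # False: kappa contributes two entries

# A name-based check is unaffected by transforms: it only verifies that
# every requested name exists in the model.
missing = set(var_names) - set(toy_model)
print(missing)  # set()
```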

@tomicapretto
Contributor Author

The current state of the PR seems to do the right thing. See the following example, which compares sampling with and without var_names.

import arviz as az
import numpy as np
import pymc as pm

batch = np.array(
    [
        1,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  4,  5,  5,  5,
        6,  6,  6,  7,  7,  7,  7,  8,  8,  8,  9,  9, 10, 10, 10
    ]
)
temp = np.array(
    [
        205, 275, 345, 407, 218, 273, 347, 212, 272, 340, 235, 300, 365,
        410, 307, 367, 395, 267, 360, 402, 235, 275, 358, 416, 285, 365,
        444, 351, 424, 365, 379, 428
    ]
)

y = np.array(
    [
        0.122, 0.223, 0.347, 0.457, 0.08 , 0.131, 0.266, 0.074, 0.182,
        0.304, 0.069, 0.152, 0.26 , 0.336, 0.144, 0.268, 0.349, 0.1  ,
        0.248, 0.317, 0.028, 0.064, 0.161, 0.278, 0.05 , 0.176, 0.321,
        0.14 , 0.232, 0.085, 0.147, 0.18
    ]
)

batch_values, batch_idx = np.unique(batch, return_inverse=True)

coords = {
    "batch": batch_values
}

with pm.Model(coords=coords) as model:
    b_batch = pm.Normal("b_batch", dims="batch")
    b_temp = pm.Normal("b_temp")
    mu = pm.Deterministic("mu", pm.math.invlogit(b_batch[batch_idx] + b_temp * temp))
    kappa = pm.Gamma("kappa", alpha=2, beta=2)
    
    alpha = mu * kappa
    beta = (1 - mu) * kappa
    
    pm.Beta("y", alpha=alpha, beta=beta, observed=y)

with model:
    idata_1 = pm.sample(random_seed=1234)
    idata_2 = pm.sample(var_names=["b_batch", "b_temp", "kappa"], random_seed=1234)

az.plot_forest([idata_1, idata_2], var_names=["b_batch"])
az.plot_forest([idata_1, idata_2], var_names=["b_temp"])
az.plot_forest([idata_1, idata_2], var_names=["kappa"])

[Figures: three forest plots comparing idata_1 and idata_2 for b_batch, b_temp, and kappa]

@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.33%. Comparing base (60a6314) to head (281fb5d).

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7284      +/-   ##
==========================================
+ Coverage   91.67%   92.33%   +0.65%     
==========================================
  Files         102      102              
  Lines       17017    17018       +1     
==========================================
+ Hits        15600    15713     +113     
+ Misses       1417     1305     -112     
Files Coverage Δ
pymc/sampling/mcmc.py 87.74% <100.00%> (+0.46%) ⬆️

... and 3 files with indirect coverage changes

Member

@fonnesbeck fonnesbeck left a comment

Good catch!

@fonnesbeck fonnesbeck merged commit a74c03f into pymc-devs:main Apr 28, 2024
22 checks passed
@ricardoV94
Member

This one warranted a regression test

Successfully merging this pull request may close these issues.

Unexpected behavior with pm.sample(var_names=...)
4 participants