ENH: Sparse matrix handling with ICAR prior #7406

Closed
jfhawkin opened this issue Jul 8, 2024 · 4 comments

jfhawkin commented Jul 8, 2024

Before

    @classmethod
    def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
        # Note: These checks are forcing W to be non-symbolic
        if not W.ndim == 2:
            raise ValueError("W must be matrix with ndim=2")

        if not W.shape[0] == W.shape[1]:
            raise ValueError("W must be a square matrix")

        if not np.allclose(W.T, W):
            raise ValueError("W must be a symmetric matrix")

        if np.any((W != 0) & (W != 1)):
            raise ValueError("W must be composed of only 1s and 0s")

        W = pt.as_tensor_variable(W, dtype=int)
        sigma = pt.as_tensor_variable(sigma)
        zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
        return super().dist([W, sigma, zero_sum_stdev], **kwargs)



    def support_point(rv, size, W, sigma, zero_sum_stdev):
        N = pt.shape(W)[-2]
        return pt.zeros(N)

    def logp(value, W, sigma, zero_sum_stdev):
        # convert adjacency matrix to edgelist representation
        # An edgelist is a pair of lists.
        # If node i and node j are connected then one list
        # will contain i and the other will contain j at the same
        # index value.
        # We only use the lower triangle here because adjacency
        # is an undirected connection.
        N = pt.shape(W)[-2]
        node1, node2 = pt.eq(pt.tril(W), 1).nonzero()

        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(value[node1] - value[node2]))
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_stdev * N)
        )

        return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")

After

    @classmethod
    def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
        # Note: These checks are forcing W to be non-symbolic
        if not W.ndim == 2:
            raise ValueError("W must be matrix with ndim=2")

        if not W.shape[0] == W.shape[1]:
            raise ValueError("W must be a square matrix")

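        # W.data is the 1-D array of stored values, so comparing it with its
        # own transpose is always true. Compare W with W.T instead: the
        # result has no stored entries exactly when W is symmetric.
        if (W != W.T).nnz != 0: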
            raise ValueError("W must be a symmetric matrix")

        if np.any((W.data != 0) & (W.data != 1)):
            raise ValueError("W must be composed of only 1s and 0s")

        sigma = pt.as_tensor_variable(sigma)
        zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
        return super().dist([W, sigma, zero_sum_stdev], **kwargs)



    def support_point(rv, size, W, sigma, zero_sum_stdev):
        N = pt.shape(W)[-2]
        return pt.zeros(N)

    def logp(value, W, sigma, zero_sum_stdev):
        # convert adjacency matrix to edgelist representation
        # An edgelist is a pair of lists.
        # If node i and node j are connected then one list
        # will contain i and the other will contain j at the same
        # index value.
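        # Note: for a symmetric W, W.nonzero() returns both (i, j) and
        # (j, i), so each undirected edge appears twice in this edge list;
        # taking the lower triangle first would list each edge once.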
        N = pt.shape(W)[-2]
        node1, node2 = W.nonzero()
        node1 = pt.as_tensor_variable(node1, dtype=int)
        node2 = pt.as_tensor_variable(node2, dtype=int)
        W = pytensor.sparse.as_sparse_or_tensor_variable(W)

        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(value[node1] - value[node2]))
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value) / (zero_sum_stdev * N), 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_stdev * N)
        )

        return check_parameters(pairwise_difference + zero_sum, sigma > 0, msg="sigma > 0")
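
For reference, the log-density that both versions of logp evaluate can be read directly off the code. In the block below, E is the edge list (node1, node2), N is the number of nodes, and s stands for zero_sum_stdev; the symbols φ, E, and s are notation chosen for this note, not names defined in PyMC.

```latex
\log p(\phi \mid W, \sigma)
  = -\frac{1}{2\sigma^{2}} \sum_{(i,j) \in E} (\phi_i - \phi_j)^{2}
    \;-\; \frac{1}{2} \left( \frac{\sum_{k=1}^{N} \phi_k}{s N} \right)^{2}
    \;-\; \log\sqrt{2\pi}
    \;-\; \log(s N)
```

The last three terms are the log-density of a normal distribution with mean 0 and standard deviation sN, evaluated at the sum of φ; this is the soft sum-to-zero constraint.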

Context for the issue:

The CAR prior allows the user to input a sparse matrix, but the ICAR prior does not. For large spatial adjacency matrices, a sparse representation is critical for keeping memory use manageable.
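
To give a rough sense of the memory difference, here is a minimal sketch that compares CSR storage with a dense float64 array for a path-graph adjacency matrix. The graph, the size N = 2000, and the variable names are assumptions made for illustration; real spatial adjacency matrices will have different sparsity patterns.

```python
import numpy as np
import scipy.sparse as sp

N = 2000  # hypothetical number of areal units

# Path graph: unit i is adjacent to units i - 1 and i + 1.
off_diagonal = np.ones(N - 1)
W_sparse = sp.diags([off_diagonal, off_diagonal], offsets=[-1, 1], format="csr")

# CSR stores only the nonzero values plus two index arrays.
sparse_bytes = (
    W_sparse.data.nbytes + W_sparse.indices.nbytes + W_sparse.indptr.nbytes
)

# A dense float64 array stores every one of the N * N entries.
dense_bytes = N * N * np.dtype(np.float64).itemsize

print(f"stored nonzeros: {W_sparse.nnz}")
print(f"sparse: {sparse_bytes} bytes, dense: {dense_bytes} bytes")
```

Dense storage grows with N squared, while the sparse storage grows with the number of edges, which is why the gap widens quickly for larger maps.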

The above solution makes minor revisions so that the checks run on a sparse matrix and the node tuple is generated from it. It may be necessary to guard the sparse-specific code with an if statement so that dense input still works, similar to the approach used in the CAR class. This version runs for me on a test dataset.
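
As a rough illustration of the if-statement idea, the sketch below branches on scipy.sparse.issparse and runs the same checks and edge-list extraction for dense and sparse input, outside of PyMC and PyTensor. The helper names validate_adjacency, edge_list, and icar_logp are hypothetical, and the logp is a plain NumPy transcription of the terms in the snippets above rather than the library implementation.

```python
import numpy as np
import scipy.sparse as sp


def validate_adjacency(W):
    """Check that W is a square, symmetric, binary adjacency matrix."""
    if W.ndim != 2 or W.shape[0] != W.shape[1]:
        raise ValueError("W must be a square matrix")
    if sp.issparse(W):
        # Compare W with its transpose; no stored entries means symmetric.
        symmetric = (W != W.T).nnz == 0
        values = W.data  # stored values only
    else:
        symmetric = np.allclose(W.T, W)
        values = W
    if not symmetric:
        raise ValueError("W must be a symmetric matrix")
    if np.any((values != 0) & (values != 1)):
        raise ValueError("W must be composed of only 1s and 0s")


def edge_list(W):
    """Return (node1, node2) with each undirected edge listed once."""
    if sp.issparse(W):
        return sp.tril(W).nonzero()
    return np.nonzero(np.tril(W))


def icar_logp(value, W, sigma=1.0, zero_sum_stdev=0.001):
    """NumPy transcription of the pairwise and zero-sum terms."""
    N = W.shape[0]
    node1, node2 = edge_list(W)
    pairwise_difference = (-1 / (2 * sigma**2)) * np.sum(
        np.square(value[node1] - value[node2])
    )
    zero_sum = (
        -0.5 * (np.sum(value) / (zero_sum_stdev * N)) ** 2
        - np.log(np.sqrt(2.0 * np.pi))
        - np.log(zero_sum_stdev * N)
    )
    return pairwise_difference + zero_sum


# Four units on a line: 0 - 1 - 2 - 3
W_dense = np.array(
    [
        [0, 1, 0, 0],
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 0],
    ]
)
W_sparse = sp.csr_matrix(W_dense)

validate_adjacency(W_dense)
validate_adjacency(W_sparse)

node1, node2 = edge_list(W_sparse)
edges = sorted(zip(node1.tolist(), node2.tolist()))
print(edges)

value = np.array([0.5, -0.5, 0.25, -0.25])
logp_dense = icar_logp(value, W_dense)
logp_sparse = icar_logp(value, W_sparse)
print(np.isclose(logp_dense, logp_sparse))
```

Taking the lower triangle before calling nonzero keeps one entry per undirected edge, so the dense and sparse paths produce the same pairwise sum.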

welcome bot commented Jul 8, 2024

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

ricardoV94 (Member) commented Jul 8, 2024

You want W wrapped in something like pytensor.sparse.as_sparse_or_tensor_variable. The shape checks can be dropped or moved into the logp; if the shape is static, they can be run up front instead.

jfhawkin (Author) commented

@ricardoV94 Running the latest version of multivariate.py, I get the following error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 12
     10 # spatially dependent random effect with non-centered parameterization
     11 icar_sigma = pm.Exponential('icar_sigma', 1)
---> 12 phi = pm.ICAR("phi", W=adj_matrix)
     13 mu_icar = icar_sigma * phi
     15 # CBSA model

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:536, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, default_transform, *args, **kwargs)
    533     elif observed is not None:
    534         kwargs["shape"] = tuple(observed.shape)
--> 536 rv_out = cls.dist(*args, **kwargs)
    538 rv_out = model.register_rv(
    539     rv_out,
    540     name,
   (...)
    546     initval=initval,
    547 )
    549 # add in pretty-printing support

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/multivariate.py:2423, in ICAR.dist(cls, W, sigma, zero_sum_stdev, **kwargs)
   2420 sigma = pt.as_tensor_variable(sigma)
   2421 zero_sum_stdev = pt.as_tensor_variable(zero_sum_stdev)
-> 2423 return super().dist([W, node1, node2, N, sigma, zero_sum_stdev], **kwargs)

File ~/.conda/envs/pymc_env/lib/python3.11/site-packages/pymc/distributions/distribution.py:618, in Distribution.dist(cls, dist_params, shape, **kwargs)
    615     ndim_supp = cls.rv_op(*dist_params, **kwargs).owner.op.ndim_supp
    617 create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp)
--> 618 rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
    620 rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
    621 rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

TypeError: ICARRV.__call__() got multiple values for argument 'size'

jfhawkin (Author) commented

I found the error. PyMC 5.16.2 from conda-forge differs from the main branch on GitHub in important ways. It is missing
size = kwargs.pop("size", None)
and several similar lines that address the positional/keyword argument issue.
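
The kind of error in the traceback can be reproduced in plain Python. The sketch below is not PyMC's code: legacy_call and updated_call are hypothetical stand-ins, and the parameter list is made up. It only illustrates how an extra positional argument can collide with a size keyword, and how popping size from kwargs avoids the collision.

```python
def legacy_call(W, sigma, zero_sum_stdev, size=None, **kwargs):
    """Hypothetical callable whose fourth positional slot is `size`."""
    return {"n_params": 3, "size": size}


def updated_call(*params, **kwargs):
    """Hypothetical callable that accepts any number of parameters."""
    # `size` can only arrive as a keyword, so positional parameters
    # can never be bound to it by accident.
    size = kwargs.pop("size", None)
    return {"n_params": len(params), "size": size}


# A caller written for the newer signature: four parameters plus `size`.
dist_params = ["W", "N", "sigma", "zero_sum_stdev"]

try:
    legacy_call(*dist_params, size=(4,))
    message = None
except TypeError as err:
    # The fourth positional argument was bound to `size`, and the
    # keyword then tried to bind it a second time.
    message = str(err)

print(message)

result = updated_call(*dist_params, size=(4,))
print(result)
```

Mixing a file from the main branch with an older installed release can produce this sort of mismatch, because the caller and the callee then disagree about how many positional parameters there are.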
