
Implement unconstraining transform for LKJCorr #7380

Open

johncant wants to merge 1 commit into main

Conversation

@johncant johncant commented Jun 21, 2024

I've ported this bijector from TensorFlow Probability and added it to LKJCorr. This ensures that initial samples drawn from LKJCorr are positive definite, which fixes #7101. Sampling now completes successfully with no divergences.

There are several parts I'm not comfortable with:

@fonnesbeck @twiecki @jessegrabowski @velochy - please could you take a look? I would like to make sure that this fix makes sense before adding tests and making the linters pass.

Notes:

  • Tests not yet written, linters not yet run
  • The original TensorFlow bijector is defined in the opposite sense to PyMC transforms, i.e. forward in tensorflow_probability is backward in PyMC
  • The original TensorFlow bijector produces Cholesky factors, not actual correlation matrices, so in this implementation we have to do a Cholesky decomposition in the forward transform.
  • In the TensorFlow bijector, the triangular elements of a matrix are filled in a clockwise spiral, as opposed to numpy, which defines indices in row-major order.

Description

Backward method

  1. Start with an identity matrix and fill the lower triangular elements with unconstrained real numbers.
  2. Normalize each row so its L2 norm is 1.
  3. The result is a Cholesky factor that always yields a positive definite correlation matrix (see the sketch below).
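
In PyTensor, the backward method looks roughly like the following sketch (it fills the triangle in row-major np.tril_indices order rather than TFP's spiral order, and the helper name is purely illustrative):

    import numpy as np
    import pytensor.tensor as pt

    def backward_sketch(value, n):
        # Step 1: identity matrix with the m = n*(n-1)/2 unconstrained reals
        # placed in the strict lower triangle (row-major order here)
        chol = pt.set_subtensor(pt.eye(n)[np.tril_indices(n, k=-1)], value)
        # Steps 2-3: normalize each row to unit L2 norm, so that
        # diag(L @ L.T) == 1 and L @ L.T is a valid correlation matrix
        row_norms = pt.sqrt(pt.sum(chol**2, axis=-1, keepdims=True))
        return chol / row_norms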

Forward method

  1. Reconstruct the correlation matrix from its upper triangular elements.
  2. Perform a Cholesky decomposition to obtain L.
  3. The diagonal elements of L are the multipliers we used to normalize the other elements.
  4. Extract those diagonal elements and divide by them to undo the backward method (see the sketch below).
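
The forward method, under the same illustrative naming (and assuming value holds the strictly upper-triangular correlation elements in np.triu_indices order):

    def forward_sketch(triu_values, n):
        # Step 1: rebuild the symmetric correlation matrix
        r, c = np.triu_indices(n, k=1)
        corr = pt.set_subtensor(pt.eye(n)[r, c], triu_values)
        corr = pt.set_subtensor(corr[c, r], triu_values)
        # Step 2: lower-triangular Cholesky factor L
        chol = pt.linalg.cholesky(corr)
        # Steps 3-4: diag(L)[i] is the reciprocal of the row norm applied in
        # the backward method, so dividing each row by its diagonal undoes it
        unconstrained = chol / pt.diagonal(chol)[:, None]
        return unconstrained[np.tril_indices(n, k=-1)]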

log_jac_det

This was quite complicated to implement in closed form, so I used the symbolic Jacobian.
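
In code, the symbolic approach is roughly the following (a sketch; backward_vec stands for the backward map written vector-to-vector over the m = n*(n-1)/2 free elements):

    def log_jac_det_sketch(value, backward_vec):
        # pt.jacobian differentiates each of the m outputs w.r.t. the m-vector
        # input, building an m x m symbolic Jacobian: correct but expensive
        jac = pt.jacobian(backward_vec(value), value)
        return pt.log(pt.abs(pt.linalg.det(jac)))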

Related Issue

  • #7101: BUG: LKJCorr breaks when used as covariance with MvNormal

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7380.org.readthedocs.build/en/7380/

welcome bot commented Jun 21, 2024

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@johncant johncant changed the title from "Fix #7101 by implementing a transform to that LKJCorr samples are positive definite" to "Fix #7101 by implementing a transform to ensure that LKJCorr samples are positive definite" on Jun 21, 2024

# Are the diagonals always guaranteed to be positive?
# I don't know, so we'll use abs
row_norms = 1/pt.abs(pt.diag(chol))
Member

Yep, always positive: the Cholesky factor of a positive definite matrix has a strictly positive diagonal. You don't need abs here.

    )

    def _jacobian(self, value, *inputs):
        return pt.jacobian(
Member

pt.jacobian can be quite expensive, because it requires us to loop over every input and compute the associated symbolic gradients. There's a closed form solution for the log-det jacobian in the TFP code, so you can eliminate this method and implement the closed form log-det jac:

    n = ps.shape(y)[-1]
    return -tf.reduce_sum(
        tf.range(2, n + 2, dtype=y.dtype) * tf.math.log(tf.linalg.diag_part(y)),
        axis=-1)

diag_part would just be pt.diagonal(y, axis1=-2, axis2=-1). That will account for potential batching on y. So something like:

    n = y.shape[-1]
    return -(pt.arange(2, n+2, dtype=y.dtype) * pt.log(pt.diagonal(y, axis1=-2, axis2=-1))).sum(axis=-1)
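
Wrapped up as a transform method, something like this sketch (assuming backward returns the full (..., n, n) Cholesky factor; the sign follows the TFP snippet above, and whether PyMC needs this or its negation depends on the transform direction, so worth double-checking):

    def log_jac_det(self, value, *inputs):
        y = self.backward(value, *inputs)
        n = y.shape[-1]
        diag = pt.diagonal(y, axis1=-2, axis2=-1)
        return -(pt.arange(2, n + 2) * pt.log(diag)).sum(axis=-1)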

Author

Can you point me to some info on how that's derived? I'm going to need to modify it to work with the correlation matrix.

Author

Never mind, I found it in the comments in TFP's _inverse_log_det_jacobian.
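
For anyone else looking, the gist of that derivation as I read it: row $i$ (zero-indexed, holding $i$ free entries $x \in \mathbb{R}^i$ plus a fixed 1 on the diagonal) is normalized by $s = \sqrt{1 + \lVert x \rVert^2}$, and the Jacobian of $x \mapsto x / s$ is

$$J = \frac{1}{s} I - \frac{x x^\top}{s^3}, \qquad \det J = s^{-(i-1)} \left( \frac{1}{s} - \frac{\lVert x \rVert^2}{s^3} \right) = s^{-(i+2)} = y_{ii}^{\,i+2},$$

so summing over rows gives $\log \lvert \det J \rvert = \sum_i (i + 2) \log y_{ii}$, which is exactly the range(2, n + 2) weighting in the TFP code.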

        row_indices, col_indices = np.tril_indices(self.n, -1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices)
Member

There's no need to save these as shared variables, you can use the numpy indices directly. Making the numpy indices is pretty cheap; I'm not sure it's worth it to cache them.
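
i.e. something like (variable names hypothetical):

    # Plain numpy indices can be used directly when indexing into the graph;
    # no shared variables required
    row_idx, col_idx = np.tril_indices(n, -1)
    tril_values = unconstrained[row_idx, col_idx]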


        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

    def backward(self, value, *inputs, foo=False):
Member

You need to check that these functions match the expected outputs from TFP. I used the test case from the tfp docs and got the wrong values -- array([0.89442719, 0.81649658, 0.91287093]) vs the reference solution

array([[1.        , 0.        , 0.        ],
       [0.70710678, 0.70710678, 0.        ],
       [0.66666667, 0.66666667, 0.33333333]])

You did some extra research so I might be missing something?

Member

Something like this matches tfp:

    from functools import partial

    import numpy as np
    import pytensor.tensor as pt

    def backward(self, value, *inputs):
        """
        Convert unconstrained real numbers to the off-diagonal elements of the
        cholesky decomposition of a correlation matrix.
        """
        def unpack_upper_tril_with_eye_diag(x, core_shape):
            """1D allocation case"""
            return pt.set_subtensor(pt.eye(core_shape)[np.tril_indices(core_shape, k=-1)], x[::-1])

        value = pt.as_tensor_variable(value)
        core_shape = value.type.shape[-1]

        # Vectorize the 1D case to handle potential batch dimensions
        out = pt.vectorize(partial(unpack_upper_tril_with_eye_diag, core_shape=core_shape), '(n)->(n,n)')(value)

        # Vector L2 norm without .real call to speed things up a bit
        norm = pt.sqrt(pt.sum((out ** 2), axis=-1, keepdims=True))
        return out / norm

Author

@johncant johncant Jun 22, 2024

Thanks for that code above, there's some great stuff I can reuse.

I need to address this comment first, since getting the transform to work at all is fundamental. Here's a notebook that demonstrates that this implementation does replicate the original reference implementation from TFP: https://colab.research.google.com/drive/1BBNNfBUNJPGT_7MxVboTqvRegJ-TUamc?usp=sharing

Here's why it didn't work for you:

  • PyMC implementation needs to output the upper triangular elements of the correlation matrix, whereas the TFP implementation outputs a Cholesky factor.
  • Differences in indexing off-diagonal elements. TFP actually fills in off-diagonal elements in a clockwise spiral, whereas np.triu_indices is row-major. I notice you reverse the elements in the code above, which is correct for 3x3 but not for larger matrices.

Member

@jessegrabowski jessegrabowski Jun 22, 2024

I got access denied to the colab :)

I'm not surprised my code doesn't work, but I'm glad you know why. In the general case, tensorflow concatenates a reflected copy of the array to itself then reshapes and masks out the lower/upper triangle -- see here if you haven't already. There's no reason why we couldn't just do that.

I'm not sure that we need to copy their output 1:1 -- after all, the important thing is that we can go from unconstrained samples to a valid cholesky decomposed correlation matrix. Is the order we put the numbers into the matrix relevant? I'm not sure, but my instinct is no. On the other hand, if we copy 1:1 we can be sure it's right.
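
For reference, a NumPy sketch of that trick, mirroring the documented behavior of tfp.math.fill_triangular:

    import numpy as np

    def fill_triangular_sketch(x, n):
        # x has n*(n+1)//2 elements; concatenate the tail with a reversed
        # copy, reshape to (n, n), and mask out the upper triangle
        return np.tril(np.concatenate([x[n:], x[::-1]]).reshape(n, n))

    # Matches the fill_triangular docs:
    # fill_triangular_sketch(np.arange(1, 7), 3)
    # -> [[4, 0, 0], [6, 5, 0], [3, 2, 1]]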

Member

PyMC implementation needs to output the upper triangular elements of the correlation matrix, whereas the TFP implementation outputs a Cholesky factor.

Are you sure? I thought the upper triangular elements were the Cholesky-factorized correlation matrix. If you're right, though, we just need to add a matmul to the end, right?

Author

Apologies - Link is now open with Viewer permission and I've made you an Editor.

I don't think the order we insert the off-diagonal elements into an array is very important, but it is needed in order to compare results between this implementation and the one in TFP. I would suggest sticking with np.triu_indices here.

Author

Are you sure? I thought the upper triangular elements were the Cholesky-factorized correlation matrix. If you're right, though, we just need to add a matmul to the end, right?

Yes, you can see this by looking at the implementation of LKJCorr. I originally thought the same thing, implemented the transform accordingly, then was surprised that non-posdef matrices were generated. 🤦
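
So the extra step on top of the TFP bijector is just the matmul you mentioned, plus extracting the upper triangle (sketch):

    # Turn the Cholesky factor into the upper-triangular correlation
    # values that LKJCorr actually stores
    corr = chol @ chol.T
    triu_values = corr[np.triu_indices(n, k=1)]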

Member

@jessegrabowski jessegrabowski Jun 24, 2024

Great! Sounds like you have a good handle on things. I think what would be a really important next step would be to add a test that your implementation correctly makes a round trip from $\mathbb R \to \Omega \to \mathbb R$, where $\Omega$ is the set of correlation matrices.
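
Something like this, say (reusing the hypothetical backward_sketch/forward_sketch helpers sketched in the description above):

    import numpy as np

    def test_round_trip(n=4, seed=0):
        rng = np.random.default_rng(seed)
        x = rng.normal(size=n * (n - 1) // 2)         # point in R^m
        chol = backward_sketch(pt.as_tensor_variable(x), n)
        corr = chol @ chol.T                          # point in Omega
        x_back = forward_sketch(corr[np.triu_indices(n, k=1)], n)
        np.testing.assert_allclose(x_back.eval(), x)  # back to R^m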

Author

Awesome - I can do that

        self.n = n
        self.m = int(n * (n - 1) / 2)  # number of off-diagonal elements
        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
Member

See below; I'm not sure we need to cache these. __init__ is probably unnecessary.

        jac = self._jacobian(value)
        return pt.log(pt.linalg.det(jac))

    def forward(self, value, *inputs):
Member

See below. I'm pretty sure this needs to go from matrix to vector (to match the tfp case). @junpenglao might know for sure.

Member

+1, it is better for the unbounded value to be a vector.

Author

Sorry, I am a bit confused and don't understand what you mean.

Specifically, do you mean this function needs to work along the last axis for arrays with an arbitrary number of dimensions, and that the current iteration assumes that value will only have one dimension?

@@ -1579,7 +1579,9 @@ def logp(value, n, eta):

@_default_transform.register(_LKJCorr)
def lkjcorr_default_transform(op, rv):
return MultivariateIntervalTransform(-1.0, 1.0)
Member

Can you delete this transform class as well? It was a (wrong) patch to the problem you're solving

Author

Can do. Just to confirm, you don't consider MultivariateIntervalTransform to be part of pymc's public API?

Member

Nope, can be removed without worries

Author

Ok - great

        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

    def _generate_tril_indices(self):
        row_indices, col_indices = np.tril_indices(self.n, -1)
Member

@ricardoV94 ricardoV94 Jun 24, 2024

Not sure if it matters, but there are pt.tril_indices and pt.triu_indices, so there's no need to eval n. If it's already restricted to be constant elsewhere (like in the logp), then it's fine either way.
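
e.g.:

    # Symbolic indices; n does not need to be a Python constant
    row_idx, col_idx = pt.tril_indices(n, -1)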

Member

I think it's good practice to use the pt version, even if n is fixed

Author

I originally tried to use the pt version, but one of the function calls required constant values. However, I've made so many changes that this might no longer be the case. I'll try the pt version again and see if I can get it to work.

@ricardoV94 ricardoV94 changed the title from "Fix #7101 by implementing a transform to ensure that LKJCorr samples are positive definite" to "Implement unconstraining transform for LKJCorr" on Jun 24, 2024
@johncant
Author

Hi, it's unlikely I'm going to have any time to work on this for the next 6 months. The hardest part is coming up with a closed-form solution for log_det_jac, which I don't think I'm very close to doing.

@twiecki
Member

twiecki commented Jul 30, 2024

Thanks for the update @johncant and for pushing this as far as you did.

Successfully merging this pull request may close these issues.

BUG: LKJCorr breaks when used as covariance with MvNormal