
Implement specialized MvNormal density based on precision matrix #7345

Merged (1 commit into pymc-devs:main on Aug 3, 2024)

Conversation

ricardoV94 (Member) commented Jun 3, 2024

Description

This PR explores a specialized logp for MvNormal (and possibly MvStudentT) parametrized directly in terms of tau. A common model implementation looks like:

import pymc as pm
import numpy as np

A = np.array([
    [0, 1, 1],
    [1, 0, 1], 
    [1, 1, 0]
])
D = np.diag(A.sum(axis=-1))  # degree matrix
np.testing.assert_allclose(A, A.T, err_msg="should be symmetric")

with pm.Model() as m:
    tau = pm.InverseGamma("tau", 1, 1)
    alpha = pm.Beta("alpha", 10, 10)
    Q = tau * (D - alpha * A)
    y = pm.MvNormal("y", mu=np.zeros(3), tau=Q)
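
For reference, here is a minimal NumPy sketch of the precision-parametrized density this specializes (illustrative only, not the PR implementation): working with tau directly avoids ever forming or inverting the covariance.

import numpy as np

def mvnormal_logp_from_precision(value, mu, tau):
    """log N(value | mu, tau**-1), using a Cholesky factor of the precision."""
    delta = value - mu
    k = value.shape[-1]
    chol = np.linalg.cholesky(tau)                    # tau = chol @ chol.T
    logdet_tau = 2 * np.log(np.diagonal(chol)).sum()  # log|tau|
    quadratic_form = delta @ tau @ delta              # delta.T @ tau @ delta
    return -0.5 * (k * np.log(2 * np.pi) - logdet_tau + quadratic_form)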

TODO (some are optional for this PR)

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

CC @theorashid @elizavetasemenova


📚 Documentation preview 📚: https://pymc--7345.org.readthedocs.build/en/7345/

ricardoV94 (Member, Author):

Implementation checks may fail until pymc-devs/pytensor#799 is fixed


ricardoV94 marked this pull request as ready for review June 21, 2024 15:00
ricardoV94 (Member, Author):

Benchmark code

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)


dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)

np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))

np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)

# Before: 2.66 ms
# After: 1.31 ms
%timeit -n 1000 logp_fn(x_test)

# Before: 2.45 ms
# After: 72 µs
%timeit -n 1000 dlogp_fn(x_test)

codecov bot commented Jun 25, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.20%. Comparing base (8eaa9be) to head (8550f01).


@@            Coverage Diff             @@
##             main    #7345      +/-   ##
==========================================
+ Coverage   92.18%   92.20%   +0.01%     
==========================================
  Files         103      103              
  Lines       17263    17301      +38     
==========================================
+ Hits        15914    15952      +38     
  Misses       1349     1349              
Files Coverage Δ
pymc/distributions/multivariate.py 93.10% <100.00%> (+0.25%) ⬆️
pymc/logprob/rewriting.py 89.18% <100.00%> (+0.17%) ⬆️

ricardoV94 (Member, Author) commented Jun 25, 2024

The final question is just whether we want / can do a similar thing for MvStudentT. Otherwise it's ready to merge on my end.

CC @elizavetasemenova

[value] = value
k = value.shape[-1]
delta = value - mean
det_sign, logdet_tau = pt.nlinalg.slogdet(tau)

Member:

I'm not sure the slogdet is necessarily a good idea here. Internally, at least in numpy, this does an LU factorization, which I think theoretically takes about twice as long as the Cholesky and should usually be less stable. (But I think the performance can differ a lot based on the number of threads and BLAS.) So for a matrix that is not constant, this might be slower than the usual MvNormal right now.

So I think it is better to use the cholesky decomposition here and use that to get the logdet.
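
A quick NumPy illustration of the two routes to the log-determinant of an SPD matrix (a sketch under the assumption that Q is symmetric positive definite; both give the same value, the Cholesky route just relies on the cheaper factorization):

import numpy as np

rng = np.random.default_rng(0)
L = rng.uniform(low=0.1, high=1.0, size=(50, 50))
Q = L @ L.T  # symmetric positive definite

# LU-based route (what np.linalg.slogdet does internally)
sign, logdet_lu = np.linalg.slogdet(Q)

# Cholesky-based route: Q = C @ C.T, so log|Q| = 2 * sum(log(diag(C)))
C = np.linalg.cholesky(Q)
logdet_chol = 2 * np.log(np.diagonal(C)).sum()

np.testing.assert_allclose(logdet_lu, logdet_chol, rtol=1e-6)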

Member (Author):

In the example above the matrix is not constant, so you can use it to benchmark

ricardoV94 (Member, Author) commented Jul 11, 2024:

And it was 2x faster logp and 100x faster dlogp on my crappy PC. Are you skeptical of those numbers, or do you think we can do even better?

Member:

Here is some code with a non-constant matrix, but I get a NotImplementedError for the grad of Blockwise?

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
import threadpoolctl  # needed for threadpool_limits used below

rng = np.random.default_rng(123)

n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)
v_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    v = pm.Normal("v", shape=n)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q + v[None, :] * v[:, None])

logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
#pytensor.dprint(logp_fn)


dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
#pytensor.dprint(dlogp_fn)

#np.testing.assert_allclose(logp_fn(x_test, v_test), np.array(-1789.93662205))

#np.testing.assert_allclose(np.sum(dlogp_fn(x_test, v_test) ** 2), np.array(18445204.8755109), rtol=1e-6)
with threadpoolctl.threadpool_limits(1):

    # Before: 2.66 ms
    # After: 1.31 ms
    %timeit logp_fn(x_test, v_test)
    
    # Before: 2.45 ms
    # After: 72 µs
    %timeit dlogp_fn(x_test, v_test)

ricardoV94 (Member, Author) commented Jul 11, 2024:

Your expression moves the log above the diagonal, which makes sense, but I think the minus in -2 * is wrong?
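
The expression being discussed is not visible in this thread, but for reference, the sign depends on which matrix gets factored; a small NumPy check (illustrative, not the thread's code) of where a +2 versus a -2 comes from:

import numpy as np

rng = np.random.default_rng(0)
L = rng.uniform(low=0.1, high=1.0, size=(20, 20))
Sigma = L @ L.T           # covariance
Q = np.linalg.inv(Sigma)  # precision (tau)

# Factor the precision: log|Q| = +2 * sum(log(diag(chol(Q))))
logdet_Q = 2 * np.log(np.diagonal(np.linalg.cholesky(Q))).sum()

# Factor the covariance instead: log|Q| = -log|Sigma| = -2 * sum(log(diag(chol(Sigma))))
logdet_Q_via_cov = -2 * np.log(np.diagonal(np.linalg.cholesky(Sigma))).sum()

np.testing.assert_allclose(logdet_Q, logdet_Q_via_cov, rtol=1e-6)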

Member:

By the way, the performance of the original code looks pretty bad, I think:

#353 μs ± 4.53 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
#1.61 ms ± 32.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

If everything is implemented well, I don't think there is a good reason why the gradient should be much slower than just the logp. Factoring the matrix once should be enough.
I guess this is because we compute the pullback of the cholesky or so, even though the pullback of the logdet is actually pretty easy if you already have a factorization...
There might be a nice use case for an OpFromGraph and overwriting the forward values hiding here somewhere...
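
For reference, a small NumPy/SciPy sketch (not this PR's code) of why the logdet pullback is cheap once a factorization exists: d log|Q| / dQ = Q^{-1}, and with the Cholesky factor Q = C @ C.T already computed for the logp, Q^{-1} = C^{-T} @ C^{-1} can be recovered from the triangular factor without factorizing again.

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
L = rng.uniform(low=0.1, high=1.0, size=(30, 30))
Q = np.linalg.inv(L @ L.T)  # some SPD precision matrix

C = np.linalg.cholesky(Q)   # would already be available from the logp
C_inv = solve_triangular(C, np.eye(Q.shape[0]), lower=True)  # C^{-1}
grad_logdet = C_inv.T @ C_inv                                # = Q^{-1} = d log|Q| / dQ

np.testing.assert_allclose(grad_logdet, np.linalg.inv(Q), rtol=1e-6)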

Member (Author):

Seems to be ~40µs faster (1.2x) with the Cholesky now that I only take the log of the diagonal. Funny enough, it's fusing the log and the sum, which is the only elemwise + reduce fusion we have :)

logdet_tau = 2 * pt.log(pt.diagonal(pt.linalg.cholesky(tau), axis1=-2, axis2=-1)).sum()
%env OMP_NUM_THREADS=1

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)

np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))

np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)

# With slogdet: 236 µs ± 31.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With cholesky logdet: 192 µs ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit -n 10000 logp_fn(x_test)

# With slogdet: 29.8 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# With cholesky logdet: 32.5 µs ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit -n 10000 dlogp_fn(x_test)

Member (Author):

Pushed a commit with the cholesky factorization

    logp = -0.5 * (k * pt.log(2 * pt.pi) - logdet + quadratic_form)
    return check_parameters(
        logp,
        (cholesky_diagonal > 0).all(-1),

Member (Author):

I assume this is the right check for posdef-ness? @aseyboldt

Member:

The cholesky will simply fail (throw an exception) if the matrix is not posdef. I think we check for that in the perform method and return nan, but I don't know what, for instance, JAX will do.

Member (Author):

Okay so we could check for nan? I'll see what JAX does

ricardoV94 (Member, Author) commented Jul 11, 2024:

I guess all(nan) will also be false so this ends up being the same thing? Or does it evaluate to nan...?

I'll check
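
For reference, a plain NumPy check (not the PR code): comparisons against NaN evaluate to False, so a NaN diagonal coming out of a Cholesky with on_error='nan' fails the positivity check rather than propagating NaN.

import numpy as np

# Hypothetical diagonal from a failed (non-posdef) factorization with on_error='nan'
cholesky_diagonal = np.array([1.0, np.nan, 2.0])
print((cholesky_diagonal > 0).all(-1))  # False, because np.nan > 0 is False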

ricardoV94 (Member, Author) commented Jul 12, 2024

Last benchmarks, running the following script:

%env OMP_NUM_THREADS=1

USE_TAU = True

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    if USE_TAU:
        x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
    else:
        x = pm.MvNormal("x", mu=pt.zeros(n), cov=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input=True
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input=True
print("dlogp")
pytensor.dprint(dlogp_fn)

%timeit -n 10000 logp_fn(x_test)
%timeit -n 10000 dlogp_fn(x_test)

USE_TAU = True, without optimization:

logp
Composite{((i2 - (i0 * i1)) - i3)} [id A] 'x_logprob' 9
 ├─ 0.5 [id B]
 ├─ DropDims{axis=0} [id C] 8
 │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id D] 7
 │     └─ Transpose{axes=[1, 0]} [id E] 5
 │        └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id F] 3
 │           ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
 │           │  └─ MatrixInverse [id H] 1
 │           │     └─ Q [id I]
 │           └─ ExpandDims{axis=1} [id J] 0
 │              └─ x [id K]
 ├─ -91.89385332046727 [id L]
 └─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id M] 6
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id N] 4
       └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
          └─ ···

dlogp
DropDims{axis=1} [id A] '(dx_logprob/dx)' 7
 └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id B] 6
    ├─ Transpose{axes=[1, 0]} [id C] 5
    │  └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
    │     └─ MatrixInverse [id E] 1
    │        └─ Q [id F]
    └─ Neg [id G] 4
       └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id H] 3
          ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
          │  └─ ···
          └─ ExpandDims{axis=1} [id I] 0
             └─ x [id J]

541 µs ± 56.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
503 µs ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

USE_TAU = True, with optimization:

logp
Composite{(i4 * ((i2 - (i0 * i1)) + i3))} [id A] 'x_logprob' 11
 ├─ 2.0 [id B]
 ├─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id C] 10
 │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id D] 9
 │     └─ Cholesky{lower=True, destructive=False, on_error='raise'} [id E] 8
 │        └─ Q [id F]
 ├─ 183.78770664093454 [id G]
 ├─ DropDims{axis=0} [id H] 7
 │  └─ CGemv{inplace} [id I] 6
 │     ├─ AllocEmpty{dtype='float64'} [id J] 5
 │     │  └─ 1 [id K]
 │     ├─ 1.0 [id L]
 │     ├─ ExpandDims{axis=0} [id M] 4
 │     │  └─ x [id N]
 │     ├─ CGemv{inplace} [id O] 3
 │     │  ├─ AllocEmpty{dtype='float64'} [id P] 2
 │     │  │  └─ Shape_i{1} [id Q] 1
 │     │  │     └─ Q [id F]
 │     │  ├─ 1.0 [id L]
 │     │  ├─ Transpose{axes=[1, 0]} [id R] 'Q.T' 0
 │     │  │  └─ Q [id F]
 │     │  ├─ x [id N]
 │     │  └─ 0.0 [id S]
 │     └─ 0.0 [id S]
 └─ -0.5 [id T]

dlogp
CGemv{inplace} [id A] '(dx_logprob/dx)' 5
 ├─ CGemv{inplace} [id B] 4
 │  ├─ AllocEmpty{dtype='float64'} [id C] 3
 │  │  └─ Shape_i{0} [id D] 2
 │  │     └─ Q [id E]
 │  ├─ 1.0 [id F]
 │  ├─ Q [id E]
 │  ├─ Mul [id G] 1
 │  │  ├─ [-0.5] [id H]
 │  │  └─ x [id I]
 │  └─ 0.0 [id J]
 ├─ -0.5 [id K]
 ├─ Transpose{axes=[1, 0]} [id L] 'Q.T' 0
 │  └─ Q [id E]
 ├─ x [id I]
 └─ 1.0 [id F]

160 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
15.9 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For reference: USE_TAU = False before and after (unchanged)

Before:

...
260 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 19.5 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

After:

...
259 µs ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Summary

  • tau used to be ~2x slower logp and ~2x slower dlogp vs direct cov, due to the extra MatrixInverse
  • tau is now ~2x faster logp and ~20x faster dlogp vs direct cov
  • tau total speedup: ~4x faster logp and ~40x faster dlogp

ricardoV94 force-pushed the mvnormal_precision_logp branch 2 times, most recently from ab6c1a3 to f50b56d on July 12, 2024 15:45
ricardoV94 (Member, Author):

The numpyro tests are failing, probably because it now requires a more recent version of JAX. Should be fixed by #7407

ricardoV94 (Member, Author):

@aseyboldt anything that should block this PR?

Co-authored-by: theorashid <theoaorashid@gmail.com>
Co-authored-by: elizavetasemenova <elizaveta.p.semenova@gmail.com>
Co-authored-by: aseyboldt <aseyboldt@users.noreply.github.com>

aseyboldt (Member):

Looks good. I think it is possible that we could further improve the MvNormal in both parametrizations, but this is definitely an improvement as it is.
Most of all, I think we should do the same for the CholeskyMvNormal. Looks like we are computing the cov just to re-compute the cholesky again? At some point we did make use of the cholesky directly, but I guess that got lost in a refactor with the pytensor RVs?

ricardoV94 (Member, Author):

> Looks good. I think it is possible that we could further improve the MvNormal in both parametrizations, but this is definitely an improvement as it is.
> Most of all, I think we should do the same for the CholeskyMvNormal. Looks like we are computing the cov just to re-compute the cholesky again? At some point we did make use of the cholesky directly, but I guess that got lost in a refactor with the pytensor RVs?

We don't recompute the Cholesky; we have rewrites to remove it, and even a specific test for it:

def test_mvnormal_no_cholesky_in_model_logp():
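
A rough sketch of the kind of check such a test can perform (hypothetical code, not the actual test from the pymc test suite; the function body and names are assumptions):

import numpy as np
import pymc as pm
from pytensor.tensor.slinalg import Cholesky

def check_no_cholesky_in_model_logp():
    # Build a chol-parametrized MvNormal and assert that no Cholesky op survives
    # in the compiled logp graph (the rewrites are expected to remove it).
    chol = np.linalg.cholesky(np.eye(3) + 0.5)
    with pm.Model() as m:
        pm.MvNormal("x", mu=np.zeros(3), chol=chol)
    logp_fn = m.compile_logp().f
    has_cholesky = any(
        isinstance(node.op, Cholesky)
        or isinstance(getattr(node.op, "core_op", None), Cholesky)
        for node in logp_fn.maker.fgraph.toposort()
    )
    assert not has_cholesky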

ricardoV94 merged commit 48e56c3 into pymc-devs:main on Aug 3, 2024
22 checks passed