Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv #1307

Merged
merged 3 commits into from
Aug 8, 2024

Conversation

barronalex
Copy link
Collaborator

Add mx.linalg.cholesky_inverse with identical functionality to torch.cholesky_inverse.

Uses lapack's optimized strtri routine for inverting triangular matrices. This makes it ~2x faster than mlx.linalg.inv and with less than half the memory usage.

e.g for an N by N matrix N = 8192, before: 4.59 sec after: 2.45 sec.

This is used by GPTQ so will be nice to have.

Copy link
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks perfect except for the Device::cpu hardcoding.

mlx/linalg.cpp Outdated
"matrices.");
}

array L_inv = tri_inv(L, upper, Device::cpu);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would probably still pass s instead of Device::cpu here. I understand we are leaving performance on the table by doing the matmul on CPU but generally all ops dispatch on the same exact stream.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense! Good motivation to have some more of linear algebra ops on GPU in the future

@awni
Copy link
Member

awni commented Aug 5, 2024

Does it make sense to make a solve_triangular (like in scipy.linalg and torch?

Would that be usable in place of cholesky inv and tri inv for the use case you are going for?

@barronalex
Copy link
Collaborator Author

barronalex commented Aug 5, 2024

You could definitely implement the above with just solve_triangular and it would be more general which is nice.

scipy uses the same lapack ops underneath and from my quick test it seems like computing $L^{-1}$ with solve_triangular instead of strtri is about 17% slower:

from scipy.linalg import solve_triangular
from scipy.linalg.lapack import strtri

N = 1024
A = np.random.normal(size=(N, N)).astype(np.float32)
A = A @ A.T
L = np.linalg.cholesky(A)

L_inv_ref = np.linalg.inv(L)
L_inv_i, _ = strtri(L, 1)
L_inv_t = solve_triangular(L, np.eye(N), lower=True)
np.testing.assert_allclose(L_inv_i, L_inv_ref, atol=1e-4)
np.testing.assert_allclose(L_inv_t, L_inv_ref, atol=1e-4)

time_fn(strtri, L, 1)
time_fn(solve_triangular, L, np.eye(N), lower=True)
time_fn(np.linalg.inv, L)

Output:

Timing function strtri ... 5.88889 msec
Timing solve_triangular ... 6.88527 msec
Timing inv ... 41.91644 msec

Do you think that justifies having a separate tri_inv API?

Either way solve_triangular seems generally useful so happy to add it.

@awni
Copy link
Member

awni commented Aug 5, 2024

Right, I meant more like if you could replace explicitly computing L_inv (which may not be so numerically stable) with the use of solve_triangular. But I guess it depends on what you do with it? Like if we want L_inv * b then we can do a triangular solve for L x = b?

But if we need the full inverse then we need it.

If we keep this API, can we make the naming consistent. Like inv vs inverse. Another option (which I might prefer) is to not add cholesky_inverse since it's quite a simple add on to tri_inv so maybe easy enough for now to do in user code?

@barronalex
Copy link
Collaborator Author

For cholesky_inverse I think we do need $L^{-1}$ directly since we construct $A^{-1} = L^{-T}L^{-1}$.

I played around with the GPTQ use case and I don't think triangular solve is sufficient unfortunately.

Definitely agree on the naming consistency. I think it's nice to have cholesky_inv in the API since it's in PyTorch but I could be convinced to leave it out too.

@awni
Copy link
Member

awni commented Aug 8, 2024

I think it's nice to have cholesky_inv in the API since it's in PyTorch

Sounds good to me!

@barronalex barronalex merged commit 32668a7 into main Aug 8, 2024
3 checks passed
@barronalex barronalex deleted the ab-cholesky-inv branch August 8, 2024 22:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants