CPU mx.linalg.cholesky_inverse and mx.linalg.tri_inv #1307
Conversation
Looks perfect except for the `Device::cpu` hardcoding.
mlx/linalg.cpp (Outdated)

```cpp
      "matrices.");
}

array L_inv = tri_inv(L, upper, Device::cpu);
```
I would probably still pass `s` instead of `Device::cpu` here. I understand we are leaving performance on the table by doing the matmul on the CPU, but generally all ops dispatch on the same exact stream.
Makes sense! Good motivation to have some more of the linear algebra ops on GPU in the future.
Does it make sense to add a `solve_triangular` op instead? Would that be usable in place of `cholesky_inverse` and `tri_inv` for the use case you are going for?
You could definitely implement the above with just `solve_triangular`:

```python
import numpy as np
from scipy.linalg import solve_triangular
from scipy.linalg.lapack import strtri

N = 1024
A = np.random.normal(size=(N, N)).astype(np.float32)
A = A @ A.T  # symmetric positive definite
L = np.linalg.cholesky(A)

L_inv_ref = np.linalg.inv(L)
L_inv_i, _ = strtri(L, 1)  # LAPACK triangular inverse, lower=1
L_inv_t = solve_triangular(L, np.eye(N), lower=True)

np.testing.assert_allclose(L_inv_i, L_inv_ref, atol=1e-4)
np.testing.assert_allclose(L_inv_t, L_inv_ref, atol=1e-4)

# time_fn is a simple timing helper (definition not shown here)
time_fn(strtri, L, 1)
time_fn(solve_triangular, L, np.eye(N), lower=True)
time_fn(np.linalg.inv, L)
```

Output:
Do you think that justifies having a separate op? Either way.
Right, I meant more like: could you replace explicitly computing `L_inv` (which may not be so numerically stable) with the use of `solve_triangular`? But I guess it depends on what you do with it. If all we want is to apply the inverse to some right-hand side, a triangular solve is enough; but if we need the full inverse, then we need it. If we keep this API, can we make the naming consistent?
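For illustration, here is a small NumPy/SciPy sketch of the trade-off discussed above (not part of this PR): applying `L^{-1}` to a right-hand side via `solve_triangular` avoids ever materializing the inverse, which is generally the more numerically stable route; the two routes agree when the factor is well-conditioned.

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
N = 64
A = rng.normal(size=(N, N))
A = A @ A.T + N * np.eye(N)  # well-conditioned SPD matrix
L = np.linalg.cholesky(A)
b = rng.normal(size=N)

# Route 1: explicitly invert L, then multiply (what an explicit L_inv enables).
x_inv = np.linalg.inv(L) @ b

# Route 2: triangular solve; no inverse is ever materialized.
x_solve = solve_triangular(L, b, lower=True)

assert np.allclose(x_inv, x_solve)
```

The solve route only helps when the inverse is consumed by a matvec or matmul; if the full inverse matrix itself is required downstream (as in the GPTQ case below), an explicit triangular inversion is still needed.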
I played around with the GPTQ use case and I don't think a triangular solve is sufficient, unfortunately. Definitely agree on the naming consistency; I think it's nice to have both.
Sounds good to me!
Adds `mx.linalg.cholesky_inverse` with identical functionality to `torch.cholesky_inverse`. It uses LAPACK's optimized `strtri` routine for inverting triangular matrices, which makes it ~2x faster than `mlx.linalg.inv` with less than half the memory usage. E.g. for an N by N matrix with `N = 8192`: before, 4.59 sec; after, 2.45 sec. This is used by GPTQ, so it will be nice to have.
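For reference, the semantics can be sketched in NumPy/SciPy (this is an illustration of what the op computes, not the MLX implementation): given the lower Cholesky factor `L` of an SPD matrix `A = L @ L.T`, `cholesky_inverse(L)` returns `A^{-1}`, which can be assembled from the triangular inverse as `L^{-T} L^{-1}`.

```python
import numpy as np
from scipy.linalg.lapack import strtri

rng = np.random.default_rng(0)
N = 16
A = rng.normal(size=(N, N))
A = A @ A.T + N * np.eye(N)  # well-conditioned SPD matrix
L = np.linalg.cholesky(A)

# Invert the lower-triangular factor with LAPACK's trtri (float32 variant).
L_inv, info = strtri(L.astype(np.float32), 1)  # lower=1
assert info == 0

# A^{-1} = (L @ L.T)^{-1} = L^{-T} @ L^{-1}
A_inv = L_inv.T @ L_inv
assert np.allclose(A_inv, np.linalg.inv(A), atol=1e-3)
```

Inverting only the triangular factor touches half the matrix, which is where the speed and memory savings over a general `inv` come from.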