Fast Hadamard Transform #1249
Very nice!! From an API perspective, I'm wondering if this should live in the fast namespace? The primitive does not have any transforms implemented (which is fine). I guess the question is really if we intend to implement them eventually. If yes, then maybe it makes sense to keep it as is. But if no, I would consider putting it in the fast package for now and maybe do a transformable fallback ( |
Agreed! The transforms are pretty simple so I added |
Looks fantastic!
I left a few nitpicks, I think we can merge after that.
P.S.: Do we want a CPU implementation from the get-go, or will we simply add it later?
mlx/primitives.cpp
Outdated
if (axes[0] == inputs[0].ndim() - 1) {
  auto a = moveaxis(inputs[0], axes[0], 0, s);
  auto b = hadamard_transform(a, scale_, s);
  return {{moveaxis(b, 0, axes[0], s)}, axes};
No need to move it back; you can just return it with the batch axis at 0, i.e. return {{b}, {0}};
mlx/primitives.h
Outdated
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Hadamard)
DEFINE_DEFAULT_IS_EQUIVALENT()
Unfortunately the default is_equivalent is incorrect, since we also need to check the scales. In general this is best left undefined unless it is certainly correct, as it can cause errors that are quite hard to debug.
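To illustrate the concern, here is a minimal, self-contained sketch using toy stand-in classes. `ToyPrimitive`, `ToyHadamardDefault`, and `ToyHadamard` are hypothetical names, not the MLX classes, and the "default" behaviour is assumed here to mean "same type implies equivalent". Under that assumption, two transforms with different scales would be treated as interchangeable unless the scale is compared explicitly.

```cpp
#include <cassert>
#include <typeinfo>

// Toy stand-ins for illustration only; these are NOT the MLX classes.
struct ToyPrimitive {
  virtual ~ToyPrimitive() = default;
  // Assumed "default" equivalence: the same dynamic type means equivalent.
  virtual bool is_equivalent(const ToyPrimitive& other) const {
    return typeid(*this) == typeid(other);
  }
};

// Relies on the default check, so the scale is silently ignored.
struct ToyHadamardDefault : ToyPrimitive {
  explicit ToyHadamardDefault(float scale) : scale_(scale) {}
  float scale_;
};

// Compares the scale as well, mirroring the shape of the review suggestion.
struct ToyHadamard : ToyPrimitive {
  explicit ToyHadamard(float scale) : scale_(scale) {}
  bool is_equivalent(const ToyPrimitive& other) const override {
    if (typeid(*this) != typeid(other)) {
      return false;
    }
    const auto& h_other = static_cast<const ToyHadamard&>(other);
    return scale_ == h_other.scale_;
  }
  float scale_;
};
```

With these toy types, `ToyHadamardDefault(1.0f).is_equivalent(ToyHadamardDefault(0.5f))` is true even though the two would produce different outputs, while the `ToyHadamard` version reports them as different.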
docs/src/python/ops.rst
Outdated
@@ -72,6 +72,7 @@ Operations
gather_qmm
greater
greater_equal
hadamard
Typo: this should be hadamard_transform.
Left a suggestion for the is_equivalent. Otherwise looks awesome!
@@ -3950,4 +3950,37 @@ bool View::is_equivalent(const Primitive& other) const {
  return (dtype_ == a_other.dtype_);
}

std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
- std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
+ bool Hadamard::is_equivalent(const Primitive& other) const {
+   const Hadamard& h_other = static_cast<const Hadamard&>(other);
+   return scale_ == h_other.scale_;
+ }
+ std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
Also needs the declaration in primitives.h, of course.
Thanks! Added this in the latest commit.
Add a fast Hadamard transform in Metal.
Supports n = m*2^k where m in (1, 12, 20, 28) (e.g. Llama 3 70B has a hidden size of 28672 = 28*1024).
Due to shared memory limits we support 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
Planning to use this to enable low-bit, online quantization of the KV cache similar to Quarot/SpinQuant.
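The size constraints above can be sketched as a small standalone check. This is only an illustration under stated assumptions: `HadamardSize`, `decompose_hadamard_size`, and `within_shared_memory_limit` are hypothetical names, not functions from this PR, and the limits are simply the figures quoted in the description.

```cpp
#include <cstdint>

// Hypothetical helper types and functions for illustration; not the MLX API.
struct HadamardSize {
  int m;       // one of 1, 12, 20, 28
  int k;       // exponent of the power-of-two factor
  bool valid;  // false when n cannot be written as m * 2^k
};

inline bool is_power_of_two(std::int64_t v) {
  return v > 0 && (v & (v - 1)) == 0;
}

// Try each supported m and accept it when n / m is a power of two.
// At most one candidate can match, because m fixes the odd part of n.
inline HadamardSize decompose_hadamard_size(std::int64_t n) {
  const int candidates[] = {1, 12, 20, 28};
  for (int m : candidates) {
    if (n > 0 && n % m == 0 && is_power_of_two(n / m)) {
      int k = 0;
      for (std::int64_t p = n / m; p > 1; p /= 2) {
        ++k;
      }
      return {m, k, true};
    }
  }
  return {0, 0, false};
}

// Limits quoted in the description:
// 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
inline bool within_shared_memory_limit(int k, bool is_fp32) {
  const std::int64_t limit = is_fp32 ? 8192 : 16384;
  return k >= 0 && k < 62 && (std::int64_t{1} << k) <= limit;
}
```

For the Llama 3 70B example, `decompose_hadamard_size(28672)` yields m = 28 and k = 10, which is within both limits.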
Benchmarks
We get close to full bandwidth for 2^k and half bandwidth for m*2^k (since we do those in two uploads). This is much faster than manually doing hadamard(N) @ x for non-trivial batch size.
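For intuition on why the fast transform beats an explicit matrix product, here is a plain CPU sketch for the power-of-two case. It is not the Metal kernel from this PR: `fwht_inplace` and `dense_hadamard_apply` are hypothetical names, and the Sylvester ordering of the Hadamard matrix is an assumption made for this sketch. The butterfly version needs O(n log n) additions, while the dense product needs O(n^2).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// In-place fast Walsh-Hadamard transform (Sylvester ordering), O(n log n).
// The length of x must be a power of two.
inline void fwht_inplace(std::vector<float>& x, float scale) {
  const std::size_t n = x.size();
  assert(n > 0 && (n & (n - 1)) == 0);
  for (std::size_t h = 1; h < n; h *= 2) {
    for (std::size_t i = 0; i < n; i += 2 * h) {
      for (std::size_t j = i; j < i + h; ++j) {
        const float a = x[j];
        const float b = x[j + h];
        x[j] = a + b;      // butterfly: sum
        x[j + h] = a - b;  // butterfly: difference
      }
    }
  }
  for (float& v : x) {
    v *= scale;
  }
}

// Dense reference: y = scale * H x with H[i][j] = (-1)^popcount(i & j), O(n^2).
inline std::vector<float> dense_hadamard_apply(
    const std::vector<float>& x, float scale) {
  const std::size_t n = x.size();
  std::vector<float> y(n, 0.0f);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      // The sign is negative when i & j has an odd number of set bits.
      bool negative = false;
      for (std::size_t bits = i & j; bits != 0; bits &= bits - 1) {
        negative = !negative;
      }
      y[i] += negative ? -x[j] : x[j];
    }
    y[i] *= scale;
  }
  return y;
}
```

Both functions agree on small inputs, and applying `fwht_inplace` twice with an overall scale of 1/n recovers the original vector, since H times H equals n times the identity.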