Fast Hadamard Transform #1249

Merged
merged 13 commits into from
Jul 10, 2024

barronalex
Collaborator

Add a fast Hadamard transform in Metal.

Supports n = m*2^k where m is one of 1, 12, 20, or 28 (e.g. Llama 3 70B has a hidden size of 28672 = 28*1024).

Due to shared memory limits we support 2^k <= 8192 for FP32 and 2^k <= 16384 for FP16/BF16.
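
As a rough illustration of the size rule above, here is a minimal sketch that splits n into m*2^k and applies the stated limits. The names `HadamardSize`, `decompose_hadamard_size`, and `is_supported` are hypothetical helpers written for this note, not functions from MLX, and the limit is applied to 2^k exactly as worded in the description.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper types/functions for illustration; not part of MLX.
struct HadamardSize {
  int m;  // one of 1, 12, 20, 28
  int k;  // exponent of the power-of-two factor
};

// Split n into m * 2^k with m in {1, 12, 20, 28}, or return nullopt.
std::optional<HadamardSize> decompose_hadamard_size(std::uint64_t n) {
  if (n == 0) return std::nullopt;
  int twos = 0;
  while (n % 2 == 0) {
    n /= 2;
    ++twos;
  }
  // n now holds the odd part of the original size.
  if (n == 1) return HadamardSize{1, twos};
  // 12 = 3*4, 20 = 5*4, 28 = 7*4: the odd part must be 3, 5, or 7 and
  // two factors of 2 are consumed by m itself.
  if ((n == 3 || n == 5 || n == 7) && twos >= 2) {
    return HadamardSize{static_cast<int>(4 * n), twos - 2};
  }
  return std::nullopt;
}

// Apply the limits quoted in the description: 2^k <= 8192 for FP32 and
// 2^k <= 16384 for FP16/BF16.
bool is_supported(std::uint64_t n, bool is_fp32) {
  const auto size = decompose_hadamard_size(n);
  if (!size) return false;
  const std::uint64_t pow2 = std::uint64_t{1} << size->k;
  return pow2 <= (is_fp32 ? 8192u : 16384u);
}
```

With this sketch, 28672 decomposes as m = 28, k = 10, which is within both limits.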

Planning to use this to enable low-bit, online quantization of the KV cache similar to Quarot/SpinQuant.

Benchmarks

We get close to full bandwidth for 2^k and half bandwidth for m*2^k (since we do those in two uploads). This is much faster than manually doing hadamard(N) @ x for non-trivial batch size.

[Benchmark plots: bench_float32, bench_float16]
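
To make the comparison with hadamard(N) @ x concrete, here is a small CPU sketch of the standard fast Walsh-Hadamard butterfly next to a dense reference. This is not the Metal kernel from this PR: it handles only power-of-two sizes, assumes Sylvester (natural) ordering, and the function names `fwht` and `dense_hadamard` are made up for illustration.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// O(n log n) butterfly for power-of-two n (Sylvester ordering assumed).
std::vector<float> fwht(std::vector<float> x, float scale = 1.0f) {
  const std::size_t n = x.size();
  if (n == 0 || (n & (n - 1)) != 0) {
    throw std::invalid_argument("size must be a power of two");
  }
  for (std::size_t h = 1; h < n; h *= 2) {
    for (std::size_t i = 0; i < n; i += 2 * h) {
      for (std::size_t j = i; j < i + h; ++j) {
        const float a = x[j];
        const float b = x[j + h];
        x[j] = a + b;      // sum half of the butterfly
        x[j + h] = a - b;  // difference half of the butterfly
      }
    }
  }
  for (float& v : x) v *= scale;
  return x;
}

// O(n^2) reference: y = scale * H x with H[i][j] = (-1)^popcount(i & j).
std::vector<float> dense_hadamard(const std::vector<float>& x,
                                  float scale = 1.0f) {
  const std::size_t n = x.size();
  std::vector<float> y(n, 0.0f);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      std::size_t bits = i & j;
      bool negative = false;
      while (bits != 0) {  // parity of the set bits in i & j
        negative = !negative;
        bits &= bits - 1;
      }
      y[i] += negative ? -x[j] : x[j];
    }
    y[i] *= scale;
  }
  return y;
}
```

The butterfly touches each element log2(n) times instead of n times, which is where the gap to the dense matmul comes from as n grows.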

@awni
Member

awni commented Jul 3, 2024

Very nice!!

From an API perspective, I'm wondering if this should live in the fast namespace?

The primitive does not have any transforms implemented (which is fine). I guess the question is really whether we intend to implement them eventually. If yes, then maybe it makes sense to keep it as is. But if not, I would consider putting it in the fast package for now and maybe adding a transformable fallback (hadamard(n) @ x) if it's not too tedious.

@barronalex
Collaborator Author

Agreed! The transforms are pretty simple so I added vjp/jvp/vmap.
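
One plausible reason the transforms are simple, stated here as a mathematical observation rather than a description of this PR's code: for power-of-two sizes the Sylvester Hadamard matrix H is symmetric and satisfies H*H = n*I. For a linear map y = scale * H * x, the JVP applies the same map to the tangent, and the VJP applies the transpose, which for symmetric H is again the same map. The sketch below checks both properties on a small dense matrix; `sylvester_hadamard`, `is_symmetric`, and `squares_to_scaled_identity` are hypothetical names.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Sylvester construction: H_1 = [1], H_2n = [[H_n, H_n], [H_n, -H_n]].
Matrix sylvester_hadamard(int k) {
  Matrix h{{1}};
  for (int step = 0; step < k; ++step) {
    const std::size_t n = h.size();
    Matrix next(2 * n, std::vector<int>(2 * n, 0));
    for (std::size_t i = 0; i < n; ++i) {
      for (std::size_t j = 0; j < n; ++j) {
        next[i][j] = h[i][j];
        next[i][j + n] = h[i][j];
        next[i + n][j] = h[i][j];
        next[i + n][j + n] = -h[i][j];
      }
    }
    h = std::move(next);
  }
  return h;
}

// H == H^T, so the transpose used by a VJP is the same matrix.
bool is_symmetric(const Matrix& h) {
  for (std::size_t i = 0; i < h.size(); ++i) {
    for (std::size_t j = 0; j < h.size(); ++j) {
      if (h[i][j] != h[j][i]) return false;
    }
  }
  return true;
}

// H * H == n * I, so scaling by 1/sqrt(n) makes the map its own inverse.
bool squares_to_scaled_identity(const Matrix& h) {
  const std::size_t n = h.size();
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      int sum = 0;
      for (std::size_t l = 0; l < n; ++l) sum += h[i][l] * h[l][j];
      const int expected = (i == j) ? static_cast<int>(n) : 0;
      if (sum != expected) return false;
    }
  }
  return true;
}
```

Whether the same symmetry holds for the m = 12, 20, 28 factors depends on which base matrices are used, which this sketch does not cover.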

Member

@angeloskath angeloskath left a comment

Looks fantastic!

I left a few nitpicks, I think we can merge after that.

P.S.: Do we want a CPU implementation from the get-go, or will we simply add it later?

if (axes[0] == inputs[0].ndim() - 1) {
  auto a = moveaxis(inputs[0], axes[0], 0, s);
  auto b = hadamard_transform(a, scale_, s);
  return {{moveaxis(b, 0, axes[0], s)}, axes};
Member

No need to move it back; you can just return it with 0 as the output axis, i.e. return {{b}, {0}};

mlx/primitives.h Outdated
DEFINE_VMAP()
DEFINE_GRADS()
DEFINE_PRINT(Hadamard)
DEFINE_DEFAULT_IS_EQUIVALENT()
Member

Unfortunately the default is_equivalent is incorrect here, since we also need to check the scales. In general it is best left undefined unless it is certainly correct, as it can cause errors that are quite hard to debug.

@@ -72,6 +72,7 @@ Operations
gather_qmm
greater
greater_equal
hadamard
Member

Typo: this should be hadamard_transform.

Member

@angeloskath angeloskath left a comment

Left a suggestion for the is_equivalent. Otherwise looks awesome!

@@ -3950,4 +3950,37 @@ bool View::is_equivalent(const Primitive& other) const {
return (dtype_ == a_other.dtype_);
}

std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
Member

Suggested change
- std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(
+ bool Hadamard::is_equivalent(const Primitive& other) const {
+   const Hadamard& h_other = static_cast<const Hadamard&>(other);
+   return scale_ == h_other.scale_;
+ }
+ std::pair<std::vector<array>, std::vector<int>> Hadamard::vmap(

Member

Also needs the declaration in primitives.h of course.

Collaborator Author

Thanks! Added this in the latest commit.
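
For anyone skimming the is_equivalent thread, here is a standalone sketch of the declaration plus definition pattern discussed above. The `Primitive` and `Hadamard` types below are toy stand-ins written for this note, not MLX's real classes, and the static_cast assumes the caller has already established that both primitives have the same concrete type.

```cpp
// Toy stand-ins for illustration only; these are not MLX's real classes.
struct Primitive {
  virtual ~Primitive() = default;
  virtual bool is_equivalent(const Primitive& other) const = 0;
};

class Hadamard : public Primitive {
 public:
  explicit Hadamard(float scale) : scale_(scale) {}

  // Declaration in the class body (the role primitives.h plays in the PR).
  bool is_equivalent(const Primitive& other) const override;

 private:
  float scale_;
};

// Definition follows the suggested change: two Hadamard ops are only
// interchangeable when their scales match.
bool Hadamard::is_equivalent(const Primitive& other) const {
  // Assumption: the caller has already checked that `other` is a Hadamard.
  const Hadamard& h_other = static_cast<const Hadamard&>(other);
  return scale_ == h_other.scale_;
}
```

A default that always returned true would treat Hadamard(1.0f) and Hadamard(0.5f) as the same operation, which is the hard-to-debug failure mode the review comment warns about.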

@barronalex barronalex merged commit a3c2873 into main Jul 10, 2024
3 checks passed
@barronalex barronalex deleted the ab-hadamard branch July 10, 2024 03:39