Feature complete Metal FFT #1102
Conversation
```diff
@@ -255,6 +257,96 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
   eval(inputs, out);
 }
```
Feels a bit wrong having this as a primitive, but I wasn't sure if there's a better way to do it.
mlx/backend/metal/device.cpp
Outdated
```diff
@@ -357,7 +357,6 @@ MTL::Function* Device::get_function_(
   }

-  mtl_func_consts->release();
```
@jagrit06 I was getting segfaults caused by this release when using function constants, but couldn't figure out the best place in the code to move it to. Any idea where it should fit in?
I think it's a bug to release that, and deleting it is correct. https://github.com/bkaradzic/metal-cpp/blob/metal-cpp_macOS14.2_iOS17.2/README.md#memory-allocation-policy
Very impressive perf! Regarding the design, there is a big style difference from other MLX ops which we should change if possible. Basically you do the dispatch at the op level rather than the Primitive level. I see how this might be easier since you have access to all the ops you need for the different FFT algorithms, but I don't think we should do it this way. The compute graph should be more independent of the implementation details. Also, I don't think the FFT plans themselves should be part of the compute graph (they are an implementation detail). This redesign may require some changes to our existing backend to make it workable for you to use the requisite back-end ops from within the FFT primitive.
That makes sense to me. It did feel a little anti-pattern bloating out the graph, but the MLX API is just really convenient!

We have really bad support for doing stuff on arrays inside primitives (MLX wasn't really designed with that in mind 😓). But I think we can improve it a lot if needed.
Awesome perf and generally very nice work! Kudos!
I left a comment on BluesteinFFTSetup to maybe avoid the double precision math. I think it should be doable, let me know if I am missing something or if it feels too experimental.
```cpp
// In numpy:
// w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
// w_q = np.fft.fft(1/w_k)
// return w_k, w_q
```
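For context, the `w_k`/`w_q` quantities above are the chirp and its transform in Bluestein's algorithm, which turns an arbitrary-length DFT into a convolution that can be done with power-of-two FFTs. A runnable NumPy sketch of the idea (the helper name `bluestein_fft` and the padding choice are mine, not the kernel's):

```python
import numpy as np

def bluestein_fft(x):
    """Length-N DFT via Bluestein's chirp-z trick (illustrative, not the MLX kernel)."""
    x = np.asarray(x, dtype=complex)
    N = len(x)
    # Chirp over [-N+1, N-1]: exactly the w_k from the kernel comment.
    w_k = np.exp(-1j * np.pi / N * (np.arange(-N + 1, N) ** 2))
    # Pad to a power of two >= 3N - 2 so circular convolution equals linear.
    M = 1 << (3 * N - 3).bit_length()
    a = np.zeros(M, dtype=complex)
    a[:N] = x * w_k[N - 1:]              # x_n * exp(-i*pi*n^2/N)
    b = np.zeros(M, dtype=complex)
    b[:2 * N - 1] = 1 / w_k              # the inverse chirp
    w_q = np.fft.fft(b)                  # precomputable, like w_q above
    c = np.fft.ifft(np.fft.fft(a) * w_q)
    # Post-multiply by the chirp to undo the quadratic phase.
    return w_k[N - 1:] * c[N - 1:2 * N - 1]
```

Because `w_k` depends only on `N`, both it and `w_q` can be computed once per plan, which is why their precision matters here.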
What do you think of section IV.E of https://mc.stanford.edu/cgi-bin/images/7/75/SC08_FFT_on_GPUs.pdf? Would it solve our problem here of avoiding double precision arithmetic?
Fig 6 is very promising :-)
This looks nice! I simplified the double precision part a bit so I think I'm going to keep it for now since it's not really an accuracy or performance bottleneck. Happy to revisit in the future though.
Force-pushed fd1c1a3 to 81096cf
OK, that took a little while, but I think the FFTs are in a reasonable state now.
mlx/fft.cpp
Outdated
```cpp
// GPU scatter for complex64 is NYI
in = scatter(tmp, std::vector<array>{}, in, std::vector<int>{}, Device::cpu);
```
Can we do that with a `slice_update` instead?
That sounds nicer -- I'll update it
mlx/backend/metal/kernels/fft.h
Outdated
```cpp
#include "mlx/backend/metal/kernels/fft/radix.h"
#include "mlx/backend/metal/kernels/fft/readwrite.h"
#include "mlx/backend/metal/kernels/steel/defines.h"
#include "mlx/backend/metal/kernels/utils.h"
```
So this is why you don't need to use the `utils()` in the JIT: it's already included here by the preprocessor.

To keep the JIT source small, it would be better to move the includes that we already have in the JIT out of this file (e.g. `kernels/utils.h`) and use the `utils()` when constructing the JIT source.

You can include `kernels/utils.h` in `fft.metal` before you include `fft.h`. I would just turn off clang formatting for that whole file so it won't mess with the include order.
```cpp
#include <metal_common>

#include "mlx/backend/metal/kernels/utils.h"
```
Note you should also remove the include here.
mlx/backend/metal/kernels/utils.h
Outdated
```cpp
METAL_FUNC float2 complex_mul(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}

// Complex mul followed by conjugate
METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) {
  return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x);
}

// Compute an FFT twiddle factor
METAL_FUNC float2 get_twiddle(int k, int p) {
  float theta = -2.0f * k * M_PI_F / p;

  float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)};
  return twiddle;
}
```
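As a sanity check on the arithmetic, these three helpers translate directly to Python. This sketch (mine, mirroring the `float2` re/im pairs as tuples) can be compared against Python's built-in complex type:

```python
import cmath
import math

def complex_mul(a, b):
    # (re, im) tuples standing in for Metal's float2
    return (a[0] * b[0] - a[1] * b[1], a[0] * b[1] + a[1] * b[0])

def complex_mul_conj(a, b):
    # complex multiply, then conjugate the result
    return (a[0] * b[0] - a[1] * b[1], -a[0] * b[1] - a[1] * b[0])

def get_twiddle(k, p):
    # k-th twiddle factor of a length-p FFT: exp(-2*pi*i*k/p)
    theta = -2.0 * k * math.pi / p
    return (math.cos(theta), math.sin(theta))
```

For example, `complex_mul((1, 2), (3, 4))` matches `(1+2j)*(3+4j)`, and `get_twiddle(k, p)` matches `cmath.exp(-2j*cmath.pi*k/p)` up to the precision of Metal's `fast::` trig.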
If the only reason you are using `utils.h` is for these, it might be cleaner to just put them in `fft.h` instead. I think they also just fit better in `fft.h` if that works; we have `complex64_t`, which should be used in general for complex muls.
Agreed, it's definitely a bit confusing otherwise. I've removed the `utils.h` include.
🚀 🚀
Proposed changes
A feature complete GPU FFT implementation in Metal.
Supports `n < 2^20` for:

- `fft`, `ifft`, `rfft`, `irfft`
- `fft2`, `ifft2`, `rfft2`, `irfft2`
- `fftn`, `ifftn`, `rfftn`, `irfftn`
Algorithms

- Stockham's algorithm for `n` where all prime factors `p` have `2 <= p <= 13`.
- Rader's algorithm for `n` with one prime factor `p > 13`, where `p - 1` can be computed via Stockham.
- Bluestein's algorithm for all other `n`.
- A four step FFT for `n > 4096`, when the FFT can no longer be done purely in GPU shared memory.

Performance
For `2 <= n < 512`, 1D complex-to-complex FFTs on my M1 Max, the average bandwidths are: [benchmark chart omitted]

So this implementation is about 2.3x faster than MPS on average, and about 27x faster than CPU MLX, which uses `pocketfft`.

This implementation specializes for different values of `n` with Metal function constants, so it will have more overhead than MPS on the first call for new Stockham/Rader sizes.
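For readers unfamiliar with the Stockham formulation used here: unlike plain Cooley-Tukey, Stockham folds the reordering into every stage by ping-ponging between two buffers, so no bit-reversal pass is needed, which is part of what makes it a good fit for GPU shared memory. A minimal radix-2 NumPy sketch of the stage structure (mine, purely illustrative, not the Metal kernel):

```python
import numpy as np

def stockham_fft(x):
    """Radix-2 Stockham autosort FFT; len(x) must be a power of two."""
    x = np.asarray(x, dtype=complex).copy()
    N = len(x)
    y = np.empty_like(x)
    n, s = N, 1          # n: current sub-transform length, s: stride
    while n >= 2:
        m = n // 2
        for p in range(m):
            w = np.exp(-2j * np.pi * p / n)   # stage twiddle (cf. get_twiddle)
            for q in range(s):
                a = x[q + s * p]
                b = x[q + s * (p + m)]
                # Butterfly writes to permuted positions, so the output of
                # the final stage is already in natural order.
                y[q + s * (2 * p)] = a + b
                y[q + s * (2 * p + 1)] = (a - b) * w
        x, y = y, x      # ping-pong buffers between stages
        n, s = m, 2 * s
    return x
```

In the Metal kernel the per-stage twiddles come from `get_twiddle` and the butterflies run across threadgroups, but the buffer ping-pong and implicit reordering are the same idea.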