JIT compile option for binary minimization #1091
mlx/backend/metal/unary.cpp
[[kernel]] void {0}_v(
    device const {1}* in,
    device {2}* out,
I put the kernels here (binary kernels in the binary.cpp, and so on). Maybe better to put them in a different file in kernels/ that gets included? Not sure if either of you have a preference there @angeloskath @jagrit06
I kinda like it like this. If we do want to do this for more complicated kernels maybe we need a solution like the preamble but for unary, binary, ternary this is pretty great imho.
I think it will likely be a combination of the two. Anything that needs to be formatted at runtime will likely be like this. And we can try to keep that bit to a minimum by having it be just the instantiations essentially.
The other stuff I will probably put in preambles. But I don't think it makes sense to do it as one giant preamble since that won't scale. So likely I will change the preamble / include stuff to be a little more modular
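For readers following along, here is a minimal sketch of what "formatted at runtime" could look like for the `{0}_v` fragment above. The helper `format_indexed`, the function `make_unary_source`, and the kernel body are all hypothetical and written only for illustration; the PR's actual formatting utility is not shown in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not from the PR): replaces every "{i}" in tmpl
// with args[i]. Braces that are not part of an indexed field are left alone.
std::string format_indexed(std::string tmpl, const std::vector<std::string>& args) {
  for (std::size_t i = 0; i < args.size(); ++i) {
    const std::string key = "{" + std::to_string(i) + "}";
    std::size_t pos = 0;
    while ((pos = tmpl.find(key, pos)) != std::string::npos) {
      tmpl.replace(pos, key.size(), args[i]);
      pos += args[i].size();
    }
  }
  return tmpl;
}

// The first three lines follow the fragment in the diff; the index
// parameter and the body are assumptions so the template is complete.
const char* kUnaryTemplate =
    "[[kernel]] void {0}_v(\n"
    "    device const {1}* in,\n"
    "    device {2}* out,\n"
    "    uint index [[thread_position_in_grid]]) {\n"
    "  out[index] = in[index];\n"
    "}\n";

// Produces Metal source text for one instantiation of the template.
std::string make_unary_source(
    const std::string& name,
    const std::string& in_type,
    const std::string& out_type) {
  assert(!name.empty());
  return format_indexed(kUnaryTemplate, {name, in_type, out_type});
}
```

Calling `make_unary_source("abs_float32", "float", "float")` yields source text whose first line is `[[kernel]] void abs_float32_v(`. Keeping only the instantiation in the formatted string, as suggested above, keeps the number of escaped or substituted fields small.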
Benchmarks: No degradation in token generation:
Transformer training:
LeNet training:
MNIST:
61d2f8b to 1ead2a5
auto& d = metal::device(s.device);

std::string kernel_name = (contig ? "v" : "g") + op + type_to_name(out);
auto kernel = get_unary_kernel(d, kernel_name, out);
@angeloskath @jagrit06 this is where the function gets linked in differently depending on the compile flag MLX_METAL_JIT. You can see the different definitions in jit_kernels.cpp and nojit_kernels.cpp.
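A rough sketch of the "one declaration, two definitions" idea, for anyone reading without the diff open. In the PR the two definitions live in separate files and the build picks one; here an `#ifdef` stands in for that so the example fits in one file. `PipelineSketch`, `make_kernel_name`, and `get_unary_kernel_sketch` are hypothetical names, and the bodies are placeholders rather than the real Metal calls.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a compiled pipeline; not an MLX or Metal type.
struct PipelineSketch {
  std::string name;
  std::string origin; // "jit" or "prebuilt"
};

// Mirrors the name construction in the diff: "v" when contiguous, "g" otherwise.
std::string make_kernel_name(
    bool contig,
    const std::string& op,
    const std::string& type_name) {
  return (contig ? "v" : "g") + op + type_name;
}

// Shared declaration that every call site sees (would live in a header).
PipelineSketch get_unary_kernel_sketch(const std::string& kernel_name);

#ifdef MLX_METAL_JIT
// Would live in the JIT source file: build the kernel on first use, then cache it.
PipelineSketch get_unary_kernel_sketch(const std::string& kernel_name) {
  assert(!kernel_name.empty());
  static std::unordered_map<std::string, PipelineSketch> cache;
  auto it = cache.find(kernel_name);
  if (it == cache.end()) {
    // Placeholder for "format source, compile at runtime".
    it = cache.emplace(kernel_name, PipelineSketch{kernel_name, "jit"}).first;
  }
  return it->second;
}
#else
// Would live in the non-JIT source file: look the kernel up in the prebuilt library.
PipelineSketch get_unary_kernel_sketch(const std::string& kernel_name) {
  assert(!kernel_name.empty());
  return PipelineSketch{kernel_name, "prebuilt"};
}
#endif
```

The call site stays identical in both builds, which is the property the comment above is pointing at.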
mlx/backend/metal/ternary.cpp
MTL::ComputePipelineState* kernel;

if constexpr (mlx_metal_jit()) {
@jagrit06 @angeloskath this is an example using the constexpr to figure out if we are JITing or not. It's not as messy as I thought it would be provided the right helper utilities.
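To make the constexpr route concrete, here is a small sketch, assuming the helper simply reflects the `MLX_METAL_JIT` compile definition. The definition of `mlx_metal_jit()` is not shown in this thread, so `metal_jit_enabled` and `kernel_origin` below are hypothetical stand-ins.

```cpp
#include <cassert>
#include <string>

// Hypothetical mirror of mlx_metal_jit(): turns the compile definition
// into a constexpr bool so call sites can avoid the preprocessor.
constexpr bool metal_jit_enabled() {
#ifdef MLX_METAL_JIT
  return true;
#else
  return false;
#endif
}

// Selects a code path at compile time; the returned strings are placeholders
// for "compile at runtime" versus "load from the prebuilt library".
std::string kernel_origin(const std::string& kernel_name) {
  assert(!kernel_name.empty());
  if constexpr (metal_jit_enabled()) {
    return "jit:" + kernel_name;
  } else {
    return "prebuilt:" + kernel_name;
  }
}
```

One caveat with this approach: in a non-template function both branches of `if constexpr` are still compiled, so every name used in the discarded branch must be declared. That is likely why some includes still need a preprocessor guard.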
mlx/backend/metal/ternary.cpp
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

#ifndef MLX_METAL_JIT
This is a minor downside of going the constexpr route, but it's the only place we really need to use the preprocessor.
bd04991 to f53af4a
@jagrit06 @angeloskath I think this is ready for review.
.circleci/config.yml
command: |
  source env/bin/activate
  cd build/
  cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel -DBUILD_SHARED_LIBS=ON -DMLX_BUILD_CPU=ON -DMLX_BUILD_SAFETENSORS=OFF -DMLX_BUILD_GGUF=OFF -DMLX_METAL_JIT=ON
Curious about MinSizeRel (which I believe favors size over code speed) vs Release in terms of size. Also, have you tried strip?
https://stackoverflow.com/questions/38675403/how-to-config-cmake-for-strip-file/38676023
It doesn't make a difference for what I tried for the GPU back-end. It might for the CPU, I haven't checked. But it's really meant as an option for deploying mostly on the GPU (when you want a really small CPU binary).
I haven't checked strip.. let me try and see if it reduces the size any further.
For a review, the main thing to look at is:
It looks great!
In MLX tradition one can take the simple route and not provide a JIT option for a new kernel, which means no change to the usage at all. That means that even for us, we can add kernels as we're used to and add a JIT option later.
kernels/reduction/ops.h
)
make_jit_source(scatter)
make_jit_source(gather)
Gather, scatter always jitted I guess. Honestly, it seems fitting. Do you foresee speedups as well for these going down this route?
Yes.. I kept them in the JIT for a few reasons:
- Kernels are pretty simple
- They take up a disproportionate amount of space for kernels which are mostly never used (e.g. nidx > 2). I think 30 MB of the Metal library.
- With the JIT we don't need to worry about a hard ceiling on the number of arguments (up to the buffer limit) which is kind of nice albeit perhaps not that useful.
- There is negligible additional change to cold start time JITing these.
- I didn't love the way we build these kernels in the preprocessor, it was kind of hard to follow and I didn't feel like adding it back ;)
> Do you foresee speedups as well for these going down this route?
If I understand your question correctly - do we plan to improve these kernels and will JITing them make it hard to do? I guess maybe.. if we get to that point and having them in the JIT is too annoying we can always revisit. (It's also pretty easy to add an instantiation which does not get included in the Metal library but is useful just for compile-time compilation.)
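As an illustration of the "no hard ceiling on the number of arguments" point, a kernel signature can be generated for any number of index buffers at runtime. This is only a sketch: `make_gather_signature`, the parameter names, the fixed `float` element type, and the buffer numbering are assumptions and do not reflect the actual gather kernel in the PR.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical generator: emits a Metal kernel signature with nidx index
// buffers. Buffers 0 and 1 hold source and output; index buffers start at 2.
std::string make_gather_signature(
    const std::string& name,
    const std::string& idx_type,
    int nidx) {
  assert(nidx >= 0);
  std::ostringstream os;
  os << "[[kernel]] void " << name << "_" << nidx << "(\n";
  os << "    device const float* src [[buffer(0)]],\n";
  os << "    device float* out [[buffer(1)]]";
  for (int i = 0; i < nidx; ++i) {
    os << ",\n    device const " << idx_type << "* idx" << i
       << " [[buffer(" << (i + 2) << ")]]";
  }
  os << ")";
  return os.str();
}
```

With this shape, `nidx = 3` costs nothing in the shipped library: the signature (and, in a fuller version, the body) is only produced when a caller actually needs it.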
> If I understand your question correctly
No, I meant the exact opposite. These kernels feel very natural to be jitted. We could even imagine taking advantage of the fact that we are jitting to make on-the-fly kernels specific to index shapes, for instance.
Oh yes very good point! In fact I had considered doing specializations in the past for different index layouts but it was too unwieldy from a combinatorial perspective
@@ -5,9 +5,11 @@
#include <metal_atomic>
#include <metal_simdgroup>

#ifndef MLX_METAL_JIT
That's a bit annoying. I think we'll waste a fair amount of time forgetting to do that. I don't have anything better to propose... just commenting.
Yea it's annoying. Getting the includes right is annoying in general.. perhaps the most annoying part of all of this. There is probably a way to avoid this but it requires some care in what you include where and in what order (which is quite brittle).
I was able to remove a few of these and I think I can remove the rest in #1132.
MLX_METAL_JIT to reduce the Metal library size by using runtime compilation.
15M mlx.metallib