Adding MLX backend. #419
Conversation
We can safely ignore the CI failure, which is entirely unrelated to this PR (a more recent clippy detects an issue with PyO3 itself, which will most likely be fixed by simply upgrading PyO3).
Dope! Thanks for putting this up! 🚀 Conceptually, this looks good to me.
That said, I'm not an expert in MLX, so I would appreciate it if Awni could do a deeper review.
One qq: can we also test persisting and loading quantised models? We don't need to be exhaustive here; just testing for 4-bit should be okay IMO.
Quantized is u8, therefore it should work out of the box (not sure how MLX persists quantization information in npy/npz; I'm guessing it's not saved).
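For illustration, a minimal sketch of the kind of round-trip test being suggested. This is hypothetical, not code from the PR: it assumes the `save_file`/`load_file` API this PR adds under `safetensors.mlx`, and uses MLX's `mx.quantize`, which returns packed integer words plus float scales and biases, all plain tensors that safetensors can store.

```python
import mlx.core as mx
from safetensors.mlx import save_file, load_file  # assumed API added by this PR

# Hedged sketch of the suggested test: quantized weights are ordinary integer
# and float tensors, so a save/load round trip should preserve them bit-for-bit.
w = mx.random.normal((64, 64))
w_q, scales, biases = mx.quantize(w, bits=4)
tensors = {"w_q": w_q, "scales": scales, "biases": biases}
save_file(tensors, "quantized.safetensors")
loaded = load_file("quantized.safetensors")
for name in tensors:
    assert mx.array_equal(tensors[name], loaded[name])
```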
Awesome! 🔥
```python
    return numpy_dict


def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.array]:
```
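For readers without the full diff, here is a minimal sketch of what a converter with this signature might look like. This is an assumption on my part, not the PR's exact code; it relies on `np.array(...)` copying an `mx.array` out through the buffer protocol.

```python
from typing import Dict

import mlx.core as mx
import numpy as np


def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.ndarray]:
    # Sketch only: np.array(...) copies the data out of each MLX array.
    # bfloat16 would need the named-dtype handling discussed below, since
    # NumPy has no native bfloat16 type.
    return {name: np.array(tensor) for name, tensor in mx_dict.items()}
```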
Does this handle `bfloat16`?
Oh, I see in the tests below that it does :)
Yup, using the same flax trick, which is just a special named dtype.
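For context, a sketch of that kind of named-dtype trick, shown here with the `ml_dtypes` package for illustration (an assumption; the PR may implement the dtype differently):

```python
import numpy as np
import ml_dtypes  # registers NumPy-compatible dtypes such as bfloat16

# NumPy has no native bfloat16; ml_dtypes supplies one whose dtype name is
# literally "bfloat16", so serialization code can key off dtype.name.
x = np.array([0.5, 1.0, 1.5, 2.0], dtype=ml_dtypes.bfloat16)
assert x.dtype.name == "bfloat16"
raw = x.tobytes()                                 # 2 bytes per element
y = np.frombuffer(raw, dtype=ml_dtypes.bfloat16)  # round-trips the raw bits
assert (x == y).all()
```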
LGTM! Thanks for adding that!
I think using your built-in `numpy.save` is probably a good call from a maintenance standpoint. Just a thought, though: this will involve a copy (on the load side as well). Currently we do a copy to get in and out of NumPy.
Good to know. So if there's a copy from NumPy -> MLX, that means two copies are currently created: File -> slice (local buffer inside Rust) -> `numpy.frombuffer` (should be zero-copy) -> MLX (new copy). Is that correct?
Yes, that's correct. We can't avoid the copy from NumPy because we have to use specifically allocated memory to make it available to both the CPU and GPU.
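Spelled out as code, the load path under discussion looks roughly like this (illustrative names; the file read below stands in for the Rust-side slice):

```python
import numpy as np
import mlx.core as mx

# Illustrative sketch of the copy chain discussed above:
# File -> slice (buffer) -> zero-copy NumPy view -> MLX copy into unified memory.
raw = open("weights.bin", "rb").read()           # buffer holding the file slice
np_view = np.frombuffer(raw, dtype=np.float32)   # zero-copy view over `raw`
mx_arr = mx.array(np_view)                       # the one unavoidable copy:
                                                 # MLX needs unified CPU/GPU memory
```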
What does this PR do?
Adds MLX backend.
MLX already has native safetensors support, but this was trivial to add.