
Adding MLX backend. #419

Merged · 4 commits into main · Jan 5, 2024
Conversation

@Narsil (Collaborator) commented Jan 4, 2024

What does this PR do?

Adds an MLX backend.
MLX already has native safetensors support, but this backend was trivial to add anyway.

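For illustration, a minimal sketch of the intended usage, assuming the new module mirrors the existing safetensors.numpy / safetensors.torch helpers (save_file / load_file); the exact names are an assumption based on those backends:

```python
# Minimal usage sketch, assuming safetensors.mlx mirrors the other backends.
import mlx.core as mx
from safetensors.mlx import save_file, load_file

tensors = {
    "embedding": mx.zeros((512, 1024)),
    "attention": mx.zeros((256, 256)),
}
save_file(tensors, "model.safetensors")

loaded = load_file("model.safetensors")
assert mx.array_equal(loaded["embedding"], tensors["embedding"])
```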

@Narsil (Collaborator, Author) commented Jan 4, 2024

We can safely ignore the CI failure; it is entirely unrelated to this PR (a more recent clippy detects an issue in PyO3 itself, which will most likely be fixed by simply upgrading PyO3).

@Vaibhavs10 (Member) left a comment

Dope! Thanks for putting this up! 🚀 Conceptually, this looks good to me.

That said, I'm not an expert in MLX, so I'd appreciate it if Awni could do a deeper review.

One quick question: can we also test persisting and loading quantised models? We don't need to be exhaustive here; just testing 4-bit should be okay IMO.

@Narsil (Collaborator, Author) commented Jan 4, 2024

Quantized weights are u8, so they should work out of the box (I'm not sure how MLX persists quantization information in npy/npz; I'm guessing it isn't saved).
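For reference, a sketch of the kind of 4-bit round-trip test being asked about. mx.quantize returns packed integer weights plus float scales/biases, all plain dtypes, so a round trip needs no extra quantization metadata; the safetensors.mlx helper names are assumed as above:

```python
# Sketch of a 4-bit quantized round-trip test (helper names assumed).
import mlx.core as mx
from safetensors.mlx import save_file, load_file

w = mx.random.normal((512, 512))
# Pack into 4-bit groups of 64; returns plain tensors, no special metadata.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)

save_file({"w_q": w_q, "scales": scales, "biases": biases}, "quant.safetensors")
loaded = load_file("quant.safetensors")
assert mx.array_equal(loaded["w_q"], w_q)
```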

@pcuenca (Member) left a comment

Awesome! 🔥

bindings/python/py_src/safetensors/mlx.py (several review threads, outdated and resolved)
return numpy_dict


def _mx2np(mx_dict: Dict[str, mx.array]) -> Dict[str, np.array]:
A Member commented:

Does this handle bfloat16?

The Member followed up:

Oh, I see in the tests below that it does :)

@Narsil (Collaborator, Author) replied:

Yup, using the same trick as the flax backend, which is just a special named dtype.
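For the curious, a sketch of that named-dtype trick, assuming the bridge relies on the ml_dtypes package (which JAX/flax use to give NumPy a bfloat16 dtype); how mlx.py applies it exactly is an assumption here:

```python
# Sketch of the named-dtype trick (exact use inside mlx.py is assumed).
# NumPy has no native bfloat16; ml_dtypes registers one whose .name
# identifies it, so the raw bits round-trip without conversion.
import ml_dtypes
import numpy as np

bf16 = np.dtype(ml_dtypes.bfloat16)
x = np.array([1.0, 2.0, 3.0], dtype=np.float32).astype(bf16)

print(x.dtype.name)      # "bfloat16" -- the name keys the serialized dtype
print(len(x.tobytes()))  # 6: two bytes per element, stored as raw bits
```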

bindings/python/tests/test_mlx_comparison.py (review thread, outdated and resolved)
@awni left a comment

LGTM! Thanks for adding that!

I think using your built-in numpy save is probably a good call from a maintenance standpoint.

But just a thought: this will involve a copy (on the load side as well). Currently we do a copy to get in and out of NumPy.

Narsil and others added 2 commits on January 5, 2024 (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
@Narsil (Collaborator, Author) commented Jan 5, 2024

> But just a thought: this will involve a copy (on the load side as well). Currently we do a copy to get in and out of NumPy.

Good to know. The only thing really needed is a frombuffer equivalent in order to get partial-tensor loading working.
On the Rust side, I use the slice information given by the user to create a CPU buffer with the correct data from within the tensor, then use frombuffer on the given framework to send that data where it belongs.

If there's a copy from numpy -> MLX, that means two copies are currently created:

File -> Slice (local buffer inside Rust) -> numpy.frombuffer (should be zero-copy) -> MLX (new copy).

Is that correct?
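The chain above, as a sketch (function name hypothetical): np.frombuffer wraps the Rust-side bytes with zero copy, while constructing the MLX array copies the data into its own memory:

```python
# Sketch of the load chain described above (load_slice is hypothetical).
import mlx.core as mx
import numpy as np

def load_slice(raw: bytes, dtype, shape):
    # Zero copy: a NumPy view over the bytes produced on the Rust side.
    np_view = np.frombuffer(raw, dtype=dtype).reshape(shape)
    # One copy: MLX allocates its own buffer for the array's data.
    return mx.array(np_view)
```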

@Narsil merged commit 56659f4 into main on Jan 5, 2024 (5 of 11 checks passed) and deleted the add_mlx branch.
@awni commented Jan 6, 2024

> If there's a copy from numpy -> MLX, that means two copies are currently created.

Yes, that's correct. We can't avoid the copy from NumPy because we have to use specially allocated memory to make the data available to both the CPU and GPU.
