Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI]: Add JEP for FFI #12632

Closed
wants to merge 5 commits into from
Closed

[FFI]: Add JEP for FFI #12632

wants to merge 5 commits into from

Conversation

sharadmv
Copy link
Collaborator

@sharadmv sharadmv commented Oct 3, 2022

Proposes a new API for calling out to foreign functions that wraps XLA's CustomCall API and JAX's primitive API.
You can check out a prototype implementation of this JEP in #12396.

You can read the HTML version of this proposal here: https://jax--12632.org.readthedocs.build/en/12632/jep/12535-ffi.html.

cc: @dfm

Tracker: #12535

@sharadmv sharadmv requested a review from hawkinsp October 3, 2022 18:50
@sharadmv sharadmv changed the title [FFI]: Add JEP Proposal for FFI [FFI]: Add JEP for FFI Oct 3, 2022
@mjsML mjsML requested review from nouiz and nvcastet October 3, 2022 20:26
docs/jep/12535-ffi.md Outdated Show resolved Hide resolved
docs/jep/12535-ffi.md Outdated Show resolved Hide resolved
@dfm
Copy link
Collaborator

dfm commented Oct 4, 2022

@sharadmv — I haven't had a chance to really dig into this in detail, but I wanted to say that from my perspective this would be a huge quality of life improvement!

@sharadmv
Copy link
Collaborator Author

sharadmv commented Oct 4, 2022

Thank you for your feedback. Your guide has been and will continue to be a huge benefit to the community!

@mattjj
Copy link
Collaborator

mattjj commented Oct 5, 2022

@dfm we've been referring people to your tutorial for years (thank you!), and as good as it is, one founding objective of this JEP was to make such tutorials unnecessary (i.e. to make your life much easier). 😁

docs/jep/12535-ffi.md Outdated Show resolved Hide resolved

Unlike in the dfm guide, users are not constructing JAX primitives and therefore don’t have the opportunity to register transformation rules for those primitives. Do we want to expose them and if so, how?

For automatic differentiation, users have the option of wrapping their `FFICall` with a `jax.custom_jvp` or `jax.custom_vjp`. Alternatively we could expose additional methods on `FFICall` that do something similar. The `jax.custom_*` (`custom_vmap`, `custom_transpose`, etc.) API, in principle, could also handle any custom behavior users want from FFI calls. However, this API has not been fully built out yet.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about multi-GPU and sharding?
I think this case should be supported as we know this is a current issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an important use case but can arguably be solved orthogonally via the WIP custom partitioning API. We're still working on its design but presumably we can surface a jax.custom_sharding or something like it that will work with FFI calls.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be great.


On CPU (for now), the `function_ptr` and `user_static_info` are passed by reference as the first argument to the custom call. On GPU, they can be passed by reference via the opaque string.

Note that XLA custom calls support custom layouts for operands and results. Here we’ll generate MHLO that uses default layouts, which technically limits what users can express.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you do not try to cover this?
Or this is planned for a next version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our thinking is:

  • We can't support this easily without either implementing our own wrapper around XLA's layout API or exposing details of XLA itself
  • Users can work around by either changing layouts manually or forgoing the FFI API and using the Custom Call API directly

In the long run, we might want to support custom layouts and that can be in a follow up version.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the other hand, we could just provide an API to specify custom major-to-minor orders in Python (like [0, 1] or [1, 0]) and that would probably be sufficient.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. Asking to not use the FFI just for that looks a very high cost. So allowing this would be great.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Peter pointed out that jaxlib uses custom layouts frequently so we should support this. Thankfully it only needs to exist in Python and doesn't require changing the C API.

sharadmv and others added 5 commits January 12, 2023 16:41
@LionSR
Copy link

LionSR commented Oct 17, 2023

If now (late 2023) I wanted to start integrating a custom op (such as those from cuSolver/Magma etc) into Jax and define autodiffs, should I wait for this PR to land?

@nouiz
Copy link
Collaborator

nouiz commented Oct 18, 2023

The most up to date doc for JAX custom operation on GPU is:
https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html

@dfm
Copy link
Collaborator

dfm commented Aug 2, 2024

Now that #21925 is merged I think we can close this: we have an FFI now!

https://jax.readthedocs.io/en/latest/ffi.html

@dfm dfm closed this Aug 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants