[FFI]: Add JEP for FFI #12632
# JAX Foreign Function Interface

* Authors: Sharad Vikram, Kuangyuan Chen, Qiao Zhang
* Date: October 3, 2022

## tl;dr

We propose a new API for user foreign function calls that eliminates much of the boilerplate currently required and integrates more tightly with JAX.

### Example usage on CPU

User-written C++ FFI function (exposed via `pybind11`):
```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h"  // Header that comes from JAX

namespace py = pybind11;

struct Info {
  float n;
};

extern "C" void add_n(JaxFFI_API* api, JaxFFIStatus* status,
                      void* static_info, void** inputs, void** outputs) {
  Info* info = (Info*) static_info;
  if (info->n < 0) {
    JaxFFIStatusSetFailure(api, status, "Info must be >= 0");
    return;
  }
  float* input = (float*) inputs[0];
  float* output = (float*) outputs[0];
  *output = *input + info->n;
}

PYBIND11_MODULE(add_n_lib, m) {
  m.def("get_function", []() {
    return py::capsule(reinterpret_cast<void*>(add_n), JAX_FFI_CALL_CPU);
  });
  m.def("get_static_info", [](float n) {
    Info* info = new Info();
    info->n = n;
    return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
      delete reinterpret_cast<Info*>(ptr);
    });
  });
}
```

JAX registration code and usage:
```python
import jax
import add_n_lib

add_n = jax.ffi.FFICall("add_n")
add_n.register_function(add_n_lib.get_function(), platform="cpu")

def f(x):
  static_info = add_n_lib.get_static_info(4.)
  return add_n(x, static_info=static_info, return_shape_dtype=x)

print(jax.jit(f)(4.))
```

## Motivation
Our motivation is three-fold:

* We’d like to make JAX’s APIs for defining foreign calls simpler. Currently, there is no official API: defining calls to foreign functions involves using XLA’s custom call API and creating JAX primitives that wrap the custom calls. Not only does this require a lot of boilerplate to hook XLA and JAX together, but it also relies on private JAX APIs, so the code can break with future JAX releases.
* We’d like to better integrate foreign calls with the rest of JAX. Currently, the only way to expose foreign calls to JAX is by creating custom primitives, which is not a stable integration point. By wrapping both the XLA custom call registration and the JAX primitive registration, we can provide a stable API and integration points with the rest of JAX, e.g. custom gradients, batching, and error handling.
* We’d like to have centralized documentation on how to extend JAX with foreign function calls. Currently, users need to combine documentation from multiple sources (XLA documentation, JAX primitive documentation). Centralizing documentation will not only prevent user confusion but also lower the barrier to using foreign functions. At the moment, nothing is documented on the JAX site itself, leading people to write and use external guides, namely [the dfm guide](https://github.com/dfm/extending-jax).

Desiderata:
* We’d like the system to be as expressive as XLA’s current custom call machinery. We don’t want to limit what sorts of programs the user can write, which would force them back onto the underlying private APIs.
* We’d like users to avoid having to learn about XLA’s custom call machinery and JAX’s primitive internals. Learning both adds a lot of mental overhead, and we’d also like the new FFI API to be agnostic to both 1) the choice of backend compiler and 2) JAX’s changing internals.

## Background

JAX uses XLA to compile staged-out Python programs, which are represented with MHLO. MHLO offers common operations used in machine learning, such as the dot product, the exponential function, and so on. Unfortunately, these operations (and compositions of them) are sometimes not sufficiently expressive or performant, and users look to external libraries (Triton, Numba) or hand-written code (CUDA, C/C++) to implement operations for their desired program.

The mechanism XLA offers for writing custom ops is the XLA `CustomCall`. Its API is general purpose, and JAX uses it for host callbacks and custom GPU/CPU library ops (`lapack`, `cublas`, `ducc`). To register your own custom op, the steps are (at a high level):
1. Compile a Python extension that exposes your code as a capsule (e.g. via `pybind11` or Cython).
2. Import that Python extension in your Python code and register it with JAX’s `xla_extension` (via `jax._src.lib.xla_client`) with a name for the custom call.
3. Create a new JAX primitive and register an MLIR lowering that emits a custom call with a matching name.

Note that in order to do this, users need to be familiar with 1) the XLA custom call API (documented separately from JAX), 2) XLA’s Python extension API (how to register custom calls), 3) JAX’s primitive API (which is internal and subject to change), and 4) the MHLO/MLIR Python bindings.

There are also many details associated with each of these steps, outlined in [the dfm guide](https://github.com/dfm/extending-jax). Note that the dfm guide uses the out-of-date XLA builder, not the MLIR builder.

## Technical challenges
JAX FFI fundamentally offers a C API that lets JAX users extend the JAX system. Concretely, JAX users write a C/C++ custom kernel that needs to be registered with the XLA custom call registry. Further, users may want to use C helpers from JAX that, for example, report errors in a way that JAX can consume. We observe three main challenges:
1. Avoiding building `jaxlib`: we do not want users to rebuild `jaxlib`, because of possible version mismatches with JAX and the notoriously long build time of the TensorFlow/XLA subtree.
2. Registering a function (a C function, for ABI compatibility) defined in a user shared object with the XLA custom call registry defined in the `jaxlib` shared object, without running into duplicate-global-object issues (the XLA custom call registry is a global variable).
3. Referencing helper functions defined in `jaxlib` from user-defined shared objects.

At a high level, all three challenges are about linking and symbol resolution in the presence of more than one shared object.

## Proposal

Here we go over the various parts of the proposed API, starting with how users expose their foreign functions to Python and ending with how those foreign functions are registered with and used in JAX.

### User foreign function API
We propose the following API for user-written foreign CPU functions (with an additional `CUstream` argument for GPU):
```c
#include "jax_ffi.h"
void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** inputs, void** outputs) {
  ...
}
```

This roughly mirrors the XLA custom call API, except that we provide `api`, which contains helper functions. `status` is used to indicate that there was an error in the computation, like `XlaCustomCallStatus`. We also provide `static_info`, which contains compile-time/lowering-time information from JAX. The JaxFFI types will be exposed in a `jax_ffi.h` header that is shipped with `jaxlib`. Note that this user code example does not reference any TensorFlow/XLA headers: the user only needs the `jax_ffi.h` header and does not need to rebuild `jaxlib`, which addresses challenge #1.

We also propose APIs that avoid some potentially unnecessary arguments (`api`, `status`, `static_info`).

Challenge #3 is about allowing users to find the helper functions exposed by `jaxlib`. Instead of the typical dynamic runtime linking (e.g., via `dlopen`), we offer a solution similar to NumPy C API linking.

`jaxlib` needs to expose a few helper functions that user kernels can invoke. We implement these functions in a new file `jax_ffi.c`. To help symbol resolution, we store the C function pointers in a pointer array called `JaxFFI_API_Table`:

```cpp
#include <optional>
#include <string>

struct JaxFFIStatus {
  std::optional<std::string> message;
};

int JAX_FFI_V1 = 1;

int JaxFFIVersionFn() {
  return JAX_FFI_V1;
}

void JaxFFIStatusSetFailureFn(JaxFFIStatus* status, const char* message) {
  status->message = std::string(message);
}

void *JaxFFI_API_Table[] = {
  (void *)&JaxFFIVersionFn,
  (void *)&JaxFFIStatusSetFailureFn
};
```

When `jaxlib` eventually invokes the user function, it passes the pointer array explicitly as the `JaxFFI_API* api` argument (note that `jaxlib` stores the pointer array). We then provide convenience macros in `jax_ffi.h` that index into the pointer array and find the appropriate helper function:

```c
#define JAX_FFI_CALL_CPU "JAX_FFI_CALL_CPU"

// Entry 0 of the table: returns the FFI version as an int.
#define JaxFFIVersion(api) \
  ((*((int (*)(void))((api)[0])))())
// Entry 1 of the table: records a failure message on `status`.
#define JaxFFIStatusSetFailure(api, status, msg) \
  ((*((void (*)(JaxFFIStatus*, const char*))((api)[1])))(status, msg))

typedef struct JaxFFIStatus JaxFFIStatus;

typedef void* JaxFFI_API;
```

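
The positional lookup that these macros perform can be imitated in Python with `ctypes`, purely as an illustration. The sketch below is not part of the proposal: the callback bodies, the `failures` dictionary, and the integer used as a fake `status` address are hypothetical stand-ins. It builds an array of untyped pointers and recovers each helper by index, which is roughly what `JaxFFIVersion` and `JaxFFIStatusSetFailure` would do in C.

```python
import ctypes

# C signatures of the two helpers, mirroring the table in the text.
VersionFn = ctypes.CFUNCTYPE(ctypes.c_int)
SetFailureFn = ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_char_p)

# Hypothetical stand-in for JaxFFIStatus storage: status address -> message.
failures = {}

def _version():
  return 1

def _set_failure(status, message):
  failures[status] = message.decode()

# Keep the callback objects alive for as long as the table is in use.
_version_cb = VersionFn(_version)
_set_failure_cb = SetFailureFn(_set_failure)

# The API table: an array of untyped pointers, indexed by position.
api_table = (ctypes.c_void_p * 2)(
    ctypes.cast(_version_cb, ctypes.c_void_p).value,
    ctypes.cast(_set_failure_cb, ctypes.c_void_p).value,
)

def jax_ffi_version(api):
  # Entry 0: reinterpret the raw address as `int (*)(void)` and call it.
  return VersionFn(api[0])()

def jax_ffi_status_set_failure(api, status, message):
  # Entry 1: reinterpret as `void (*)(JaxFFIStatus*, const char*)`.
  SetFailureFn(api[1])(status, message.encode())
```

Because the caller only ever indexes into a table it was handed, it never needs to resolve a `jaxlib` symbol by name, which is the property the proposal relies on.
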
### Exposing foreign functions to JAX

To address challenge #2, namely the need to register the function inside the XLA extension shared object, we require the user to expose the foreign function to Python. This can be done in a variety of ways, but fundamentally we need to produce Python bindings that expose the function pointer `&user_function`. The function pointer can be handled opaquely in Python via a `PyCapsule`. Here’s an example using `pybind11`:

```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h"

namespace py = pybind11;

void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** inputs, void** outputs) {
  ...
}

PYBIND11_MODULE(user_function_lib, m) {
  m.def("get_function", []() {
    return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
  });
}
```

The capsule should be given a name, both to indicate the type of signature (CPU vs GPU) and to throw errors early if the wrong function is registered.

### Passing in custom descriptors
Foreign functions often need more than just “runtime information” like the values of input arrays. “Static information” that is available in JAX at tracing/lowering/compile time also needs to be passed into the foreign function. XLA (currently) offers two separate mechanisms for providing this static information to custom calls.

On CPU, there is no official mechanism, but this information can often be provided by passing a pointer value as an argument to the custom call, which points to a heap-allocated object. Inside the custom call, the pointer can be dereferenced to get the object and access its information.

On GPU, the custom call API offers `opaque`, a string that will be passed to the custom call. This requires the information passed to the custom call to be serializable. Note that we can also “sneak” pointer values into the opaque string, allowing us to pass heap-allocated objects as well.

From the user’s perspective, these details are unnecessary and can be handled internally by JAX. The user should have a single API for passing this static information into a custom call.

Suppose the user wants to pass a struct `Info` into their foreign function. In order to do so, the `Info` struct (or a pointer to it) needs to be available to Python so that JAX can construct MHLO that passes it back into the custom call.

#### Exposing static information to Python via a pointer (pass by reference)

```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h"

namespace py = pybind11;

struct Info {
  float n;
};

void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** inputs, void** outputs) {
  Info* info = (Info*) static_info;
  ...
}

PYBIND11_MODULE(user_function_lib, m) {
  m.def("get_function", []() {
    return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
  });
  m.def("make_info", [](float n) {
    Info* info = new Info();
    info->n = n;
    return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
      delete reinterpret_cast<Info*>(ptr);
    });
  });
}
```

This approach wraps a heap-allocated object in a capsule and destroys the object when the capsule is destroyed. This hands ownership of the object to Python. JAX then takes ownership of the object and gives it to the executable, as it does with other capsule objects. This is how JAX handles host callbacks.

#### Exposing static information to Python via serialization (pass by value)

```cpp
m.def("make_info", [](float n) {
  Info info;
  info.n = n;
  return pybind11::bytes(reinterpret_cast<const char*>(&info), sizeof(Info));
});
```

This approach serializes the struct as a string, then returns it to Python as a `bytes` object. Since we’re not doing any heap allocation, we don’t need to worry about ownership and don’t require JAX to keep a heap object alive. This is how JAX handles custom calls to external libraries like cublas, lapack, and ducc.

JAX should handle both cases (pass by reference and pass by value) and pass the appropriate pointer back into the user foreign function.

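
To show what the pass-by-value payload looks like from the Python side, here is a minimal sketch using the standard `struct` module. It assumes `Info` is exactly one 32-bit float with no padding, as in the example; `INFO_FORMAT` and `read_info` are hypothetical names used only for illustration.

```python
import struct

# Layout of `struct Info { float n; }`: one 32-bit float in native byte order.
INFO_FORMAT = "=f"

def make_info(n):
  """Serialize Info by value, like the `pybind11::bytes` version."""
  return struct.pack(INFO_FORMAT, n)

def read_info(payload):
  """What the receiving side does: reinterpret the bytes as an Info."""
  if len(payload) != struct.calcsize(INFO_FORMAT):
    raise ValueError("payload size does not match Info")
  (n,) = struct.unpack(INFO_FORMAT, payload)
  return n
```

The payload is a plain `bytes` value with no destructor attached, so nothing has to be kept alive on its behalf.
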
### Handling foreign functions and descriptors in Python

We’ve shown how users expose foreign functions to Python. Now we’ll show how users register and use these functions with JAX.

#### Registering FFI calls in JAX
First we introduce a new JAX module, `jax.ffi`, which will expose a `jax.ffi.FFICall` object.

```python
from typing import Any, Dict

Platform = FunctionPointer = Any

class FFICall:
  name: str
  _registry: Dict[Platform, FunctionPointer]

  def register_function(self, function_ptr, *, platform):
    ...

  def __call__(self, *args, **kwargs):
    ...
```

We can construct `FFICall`s with a string name that uniquely identifies them (we should raise an error if the same name is used twice).

```python
import jax.ffi

user_function = jax.ffi.FFICall("user_function")
```

We allow users to register platform-specific implementations for the FFI call.
```python
import user_function_lib  # the Python extension
user_function.register_function(user_function_lib.get_function(), platform="cpu")
```

This allows users to write both a CPU version and a GPU version of their code.
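
The registration behavior described above can be modeled in a few lines of plain Python. This is a toy sketch under stated assumptions, not the proposed implementation: `FFICallSketch` and `implementation_for` are hypothetical names, and rejecting a second registration for the same platform is an assumption the text does not make.

```python
from typing import Any, Dict

# Names already claimed by an FFI call; duplicates are rejected.
_registered_names = set()

class FFICallSketch:
  """Toy model of the registration half of `FFICall` (no lowering)."""

  def __init__(self, name: str):
    if name in _registered_names:
      raise ValueError(f"FFI call {name!r} is already defined")
    _registered_names.add(name)
    self.name = name
    self._registry: Dict[str, Any] = {}

  def register_function(self, function_ptr: Any, *, platform: str) -> None:
    # Assumption: one implementation per platform.
    if platform in self._registry:
      raise ValueError(f"{self.name!r} already has a {platform!r} implementation")
    self._registry[platform] = function_ptr

  def implementation_for(self, platform: str) -> Any:
    try:
      return self._registry[platform]
    except KeyError:
      raise NotImplementedError(
          f"{self.name!r} has no implementation for platform {platform!r}"
      ) from None
```

At lowering time, a lookup such as `implementation_for("gpu")` would fail with a clear message if only a CPU function had been registered.
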

#### Calling foreign functions from JAX

`FFICall` objects have a `__call__` method that invokes a JAX primitive, `jax_ffi_call`, which already has transformation rules registered.

```python
@jax.jit
def f(x):
  return user_function(x, return_shape_dtype=x)
```

The user needs to provide `return_shape_dtype`, since JAX cannot infer the output shape and dtype of a foreign function and requires statically known shapes and dtypes for all values.

To pass in a descriptor as well, users can construct static information and pass it into `user_function` via a reserved keyword argument, `static_info`.

```python
@jax.jit
def f(x):
  static_info = user_function_lib.make_info(4.)
  return user_function(x, static_info=static_info, return_shape_dtype=x)
```

#### JAX custom call wrapper

When the user eventually calls the `FFICall`, we emit a specific MHLO custom call (`jax_ffi_call`) during lowering. This custom call is already registered and is passed both the function pointer capsule (registered earlier) and the static info (passed into the primitive). It then prepares `api` and `status` and calls the function pointer with `api`, `status`, the static info, and the input/output pointers.

```cpp
struct Descriptor {
  void* function_ptr;
  void* user_static_info;
};

extern "C" void JaxFFICallWrapper(void* output, void** inputs,
                                  XlaCustomCallStatus* status) {
  auto descriptor = reinterpret_cast<Descriptor*>(*static_cast<uintptr_t*>(inputs[0]));
  inputs += 1;
  JaxFFIStatus jax_ffi_status;
  auto function_ptr = reinterpret_cast<void (*) (JaxFFI_API*, JaxFFIStatus*,
      void*, void**, void**)>(descriptor->function_ptr);
  function_ptr(JaxFFI_API_Table, &jax_ffi_status,
               descriptor->user_static_info,
               inputs, reinterpret_cast<void**>(output));
  if (jax_ffi_status.message) {
    // Handle error!
  }
}
```

On CPU (for now), `function_ptr` and `user_static_info` are passed by reference as the first argument to the custom call. On GPU, they can be passed by reference via the opaque string.

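
The control flow of the wrapper can be modeled in plain Python to make the dispatch easier to follow. Everything here is a hypothetical stand-in: the `Descriptor` and `Status` dataclasses, the dictionary used as the API table, and the Python `add_n` are illustrations of the C++ above, and raising `RuntimeError` is only one possible reading of the "Handle error!" branch.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

@dataclass
class Status:
  message: Optional[str] = None  # None means success, as in JaxFFIStatus

@dataclass
class Descriptor:
  function: Callable[..., None]  # stands in for `function_ptr`
  user_static_info: Any

def _set_failure(status, message):
  status.message = message

# Stand-in for JaxFFI_API_Table.
API_TABLE = {"set_failure": _set_failure}

def add_n(api, status, static_info, inputs, outputs):
  # Python rendering of the C++ `add_n` example.
  if static_info["n"] < 0:
    api["set_failure"](status, "Info must be >= 0")
    return
  outputs[0] = inputs[0] + static_info["n"]

def ffi_call_wrapper(outputs: List[Any], inputs: List[Any]) -> None:
  # The descriptor travels as the first operand; the rest are user inputs.
  descriptor, user_inputs = inputs[0], inputs[1:]
  status = Status()
  descriptor.function(API_TABLE, status, descriptor.user_static_info,
                      user_inputs, outputs)
  if status.message is not None:
    raise RuntimeError(status.message)
```

The user function never sees the descriptor itself: the wrapper strips it off and forwards only the static info and the remaining operands.
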
Note that XLA custom calls support custom layouts for operands and results. Here we’ll generate MHLO that uses default layouts, which technically limits what users can express.

### Handling JAX transformations

Unlike in the dfm guide, users are not constructing JAX primitives and therefore don’t have the opportunity to register transformation rules for those primitives. Do we want to expose them and if so, how?

For automatic differentiation, users have the option of wrapping their `FFICall` with `jax.custom_jvp` or `jax.custom_vjp`. Alternatively, we could expose additional methods on `FFICall` that do something similar. The `jax.custom_*` API (`custom_vmap`, `custom_transpose`, etc.) could, in principle, also handle any custom behavior users want from FFI calls. However, this API has not been fully built out yet.

> What about multi-GPU and sharding?

> This is an important use case but can arguably be solved orthogonally via the WIP custom partitioning API. We're still working on its design but presumably we can surface a

> That would be great.

For now, we propose not committing to a specific transformation API for custom calls and waiting to see how the `custom_*` solution plays out. If users want very specific transformation behavior, they can rely on the (internal) primitive API, i.e. the status quo. Custom transformation behavior is orthogonal to the problem of enabling foreign function calls and fits into the larger discussion of how to expose custom primitives to users.

We could also consider some default transformation rules, for example under `vmap`. If the user promises their function is pure, we can adopt a strategy like that of `pure_callback`, where we map sequentially over the batch dimensions. If the function is not pure, we disable most transformations.

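
A sequential default batching rule of this kind can be sketched with plain Python lists. This is only an illustration of the idea, assuming the batch axis is the leading axis of every argument; `sequential_batch` is a hypothetical helper and does not use any JAX API.

```python
from typing import Callable, List, Sequence

def sequential_batch(fn: Callable[..., float]) -> Callable[..., List[float]]:
  """Map `fn` over the leading axis of every argument, one call per element."""
  def batched(*args: Sequence[float]) -> List[float]:
    sizes = {len(a) for a in args}
    if len(sizes) != 1:
      raise ValueError("all arguments must share the same batch size")
    # One foreign call per batch element; only valid if `fn` is pure.
    return [fn(*elems) for elems in zip(*args)]
  return batched
```

This trades performance for generality: the foreign function runs once per batch element, but it needs no batching support of its own.
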
### Error handling

In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like the one in XLA’s custom call, we don’t have to use `XlaCustomCallStatusSetFailure`, which has specific operational semantics. Instead, we can hook into the extensible error API described in the error handling JEP. Creating this layer of indirection allows us to have functional error handling in custom calls as well.

Custom calls also tend to fail in a few different ways. In `jaxlib`, custom calls usually call `XlaCustomCallStatusSetFailure` when there is an unrecoverable failure (CUDA errors, OOMs, etc.). Arguably we shouldn’t handle these errors in JAX itself. Other sorts of errors, for example numerical errors in linear algebra routines, could be surfaced in JAX via the unified error handling API. We should consider having an extra bit (e.g. `recoverable`) in `JaxFFIStatusSetFailure` that distinguishes between these two types of errors.

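
One way to picture the extra bit is the following Python sketch. It is speculative: `FFIStatus`, `set_failure`, `check_status`, and the two exception classes are hypothetical names, and how JAX would actually surface recoverable errors is left to the error handling JEP.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FFIStatus:
  """Stand-in for JaxFFIStatus with the extra `recoverable` bit."""
  message: Optional[str] = None
  recoverable: bool = False

class RecoverableFFIError(Exception):
  """An error JAX could surface through its error handling API."""

class FatalFFIError(RuntimeError):
  """An error that should abort execution (e.g. a device failure)."""

def set_failure(status, message, *, recoverable=False):
  status.message = message
  status.recoverable = recoverable

def check_status(status):
  # Called by the wrapper after the user function returns.
  if status.message is None:
    return
  if status.recoverable:
    raise RecoverableFFIError(status.message)
  raise FatalFFIError(status.message)
```

A numerical failure in a linear algebra routine would set `recoverable=True`, while a CUDA error would leave it at the default.
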
## Implementation Plan

We provide a prototype of the proposed API in [this PR](https://github.com/google/jax/pull/12396).

> Why don't you try to cover this? Or is this planned for a later version?

> Our thinking is:
> In the long run, we might want to support custom layouts and that can be in a follow-up version.

> On the other hand, we could just provide an API to specify custom major-to-minor orders in Python (like [0, 1] or [1, 0]) and that would probably be sufficient.

> I think so. Asking users not to use the FFI just for that looks like a very high cost. So allowing this would be great.

> Peter pointed out that jaxlib uses custom layouts frequently so we should support this. Thankfully it only needs to exist in Python and doesn't require changing the C API.