Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI]: Add JEP for FFI #12632

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 338 additions & 0 deletions docs/jep/12535-ffi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
# JAX Foreign Function Interface

* Authors: Sharad Vikram, Kuangyuan Chen, Qiao Zhang
* Date: October 3, 2022

## tl;dr

We propose a new API for user foreign functions calls, eliminating a lot of currently needed boilerplate and providing a tighter integration with JAX.

### Example usage on CPU

User-written C++ FFI function (exposed via `pybind11`)
```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h" // Header that comes from JAX

namespace py = pybind11;

struct Info {
float n;
};

extern "C" void add_n(JaxFFI_API* api, JaxFFIStatus* status,
void* static_info, void** inputs, void** outputs) {
Info* info = (Info*) static_info;
if (info->n < 0) {
JaxFFIStatusSetFailure(api, status, "Info must be >= 0");
return;
}
float* input = (float*) inputs[0];
float* output = (float*) outputs[0];
*output = *input + info->n;
}

PYBIND11_MODULE(add_n_lib, m) {
m.def("get_function", []() {
return py::capsule(reinterpret_cast<void*>(add_n), JAX_FFI_CALL_CPU);
});
m.def("get_static_info", [](float n) {
Info* info = new Info();
info->n = n;
return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
delete reinterpret_cast<Info*>(ptr);
});
});
}
```

JAX registration code and usage
```python
import jax
import add_n_lib

add_n = jax.ffi.FFICall("add_n")
add_n.register_function(add_n_lib.get_function(), platform="cpu")

def f(x):
static_info = add_n_lib.get_static_info(4.)
return add_n(x, static_info=static_info, out_shape_dtype=x)

print(jax.jit(f)(4.))
```

## Motivation
Our motivation is three-fold:

* We’d like to make JAX’s APIs for defining foreign calls simpler. Currently, we have no official API and defining calls to foreign functions involves using XLA’s custom call API and creating JAX primitives that wrap the custom calls. Not only does this involve a lot of boilerplate to hook XLA and JAX together, but it also uses private JAX APIs, making the code unstable with future JAX releases.
* We’d like to better integrate foreign calls with JAX’s system. Currently, the only way foreign calls can be exposed with JAX is via creating custom primitives, which is not a stable integration point. By wrapping both the XLA custom call registration and the JAX primitive registration, we can provide a stable API and provide integration points with the rest of JAX, e.g. custom gradients, batching, and error handling.
* We’d like to have centralized documentation on how to extend JAX with foreign function calls. Currently users need to combine documentation from multiple sources (XLA documentation, JAX primitive documentation). Centralizing documentation will not only prevent user confusion but lower the barrier to using foreign functions. At the moment, nothing is documented on the JAX site itself, leading people to write and use external guides, namely [the dfm guide](https://github.com/dfm/extending-jax).

Desiderata:
* We’d like the system to be as expressive as XLA’s current custom call machinery. We don’t want to limit what sorts of programs the user can write, forcing them to use the underlying private APIs.
* We’d like users to avoid having to learn about XLA’s custom call machinery and JAX’s primitive internals. For one, it adds a lot of mental overhead, but also we’d like a new FFI API to be both agnostic to 1) the choice of backend compiler and 2) JAX’s changing internals.

## Background

JAX uses XLA to compile staged-out Python programs, which are represented with MHLO. MHLO offers common operations used in machine learning, such as the dot-product, exponential function, and so on. Unfortunately, sometimes these operations (and compositions of these operations) are not sufficiently expressive or performant and users look to external libraries (Triton, Numba) or hand-written code (CUDA, C/C++) to implement operations for their desired program.

The mechanism XLA offers to write custom ops is the XLA `CustomCall`. Its API is general purpose and JAX uses it for host callbacks and custom GPU/CPU library ops (`lapack`, `cublas`, `ducc`). To register your own custom op, the steps are (at a high level):
1. Compile a Python extension that exposes your code as a capsule (e.g. via `pybind11`, Cython).
2. Import that Python extension in your Python code and register it with JAX’s `xla_extension` (via `jax._src.lib.xla_client`) with a name for the custom call.
3. Create a new JAX primitive and register a MLIR lowering that emits a custom call with a matching name.

Note that in order to do this, users need to be familiar with 1) the XLA custom call API (documented separately from JAX), 2) XLA’s Python extension API (how to register custom calls), and 3) JAX’s primitive API (which is internal and subject to change) and 4) the MHLO/MLIR python bindings.

There are also many details associated with each of these steps, outlined in [the dfm guide](https://github.com/dfm/extending-jax). Note that the dfm guide uses the out-of-date XLA builder, not MLIR builder.

## Technical challenges
JAX FFI fundamentally offers a C API to JAX users to extend the JAX system. Concretely, JAX users write a C/C++ custom kernel that needs to be registered with the XLA custom call registry. Further, users may want to use C helpers from JAX that for example report errors in a way that JAX can consume. We observe these three main challenges:
1. Avoiding building `jaxlib` – we do not want users to rebuild `jaxlib` because of possible version mismatch with JAX and also the notoriously long build time for TensorFlow/XLA subtree.
2. Registering a function (C function for ABI compatibility) defined in a user shared object with the XLA custom call registry defined in `jaxlib` shared object without running into duplicate global objects (XLA custom call registry is a global variable) issues
3. Referencing helper functions defined in `jaxlib` in user defined shared objects

At a high level, all three challenges are about linking and symbol resolution in the presence of more than one shared object.

## Proposal

Here we go over the various parts of the proposed API, starting with how users expose their foreign functions to Python and ending with how those foreign functions are registered with and used in JAX.

### User foreign function API
We propose the following API for user-written foreign CPU functions (with an additional `CUstream` argument for GPU):
```c
#include "jax_ffi.h"
void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** outputs, void** inputs) {
...
}
```

This roughly mirrors the XLA custom call API, except we provide `api`, which contains helper functions. `status` is used to indicate that there was an error in the computation like `XlaCustomCallStatus`. We also provide `static_info`, which contains compile-time/lowering-time information from JAX. The JaxFFI types will be exposed in a `jax_ffi.h` header that is shipped with `jaxlib`. Note that in this user code example, we do not reference any TensorFlow/XLA headers, meaning that the user only needs the `jax_ffi.h` header that will be shipped with `jaxlib` and users do not need to rebuild `jaxlib`, addressing challenge #1.

We also propose APIs that avoid some potentially unnecessary arguments (`api`, `status`, `static_info`).

Challenge #3 is about allowing users to find the helper functions exposed by `jaxlib`. Instead of the typical dynamic runtime linking (e.g., via `dlopen`), we offer a solution similar to Numpy C API linking.

`jaxlib` needs to expose a few helper functions that user kernels can invoke. We implement these functions in a new file `jax_ffi.c`. To help symbol resolution, we store a pointer array of C function pointers called `JaxFFI_API_Table`:

```c
struct JaxFFIStatus {
std::optional<std::string> message;
};

int JAX_FFI_V1 = 1;

int JaxFFIVersionFn() {
return JAX_FFI_V1;
}

void JaxFFIStatusSetFailureFn(JaxFFIStatus* status, const char* message) {
status->message = std::string(message);
}

void *JaxFFI_API_Table[] = {
(void *)&JaxFFIVersionFn,
(void *)&JaxFFIStatusSetFailureFn
};
```

When jaxlib eventually invokes the user function, jaxlib will pass in the pointer array as an argument explicitly as `JaxFFI_API* api` (note that `jaxlib` stores the pointer array). We then provide macros in `jax_ffi.h` as convenience methods to index into the pointer array and find the appropriate helper function:

```c
#define JAX_FFI_CALL_CPU "JAX_FFI_CALL_CPU"

#define JaxFFIVersion() \
((*((void (*)())(api[0])))())
#define JaxFFIStatusSetFailure(api, status, msg) \
((*((void (*)(JaxFFIStatus*, const char*))(api[1])))(status, msg))

struct JaxFFIStatus;

typedef void* JaxFFI_API;
```


### Exposing foreign functions to JAX

To address challenge #2, namely that of needing to register in the XLA extension shared object, we require the user to expose the foreign function to Python. This can be done in a variety of ways, but fundamentally we need to produce Python bindings that expose the function pointer `&user_function`. The function pointer can be handled in Python opaquely via a `PyCapsule`. Here’s an example using `pybind11`:

```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h"

void user_function(JaxFFI_API* api, JaxFFIStatus* status, void* static_info, void** outputs, void** inputs) {
...
}

PYBIND11_MODULE(user_function_lib, m) {
m.def("get_function", []() {
return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
});
}
```

The capsule should be given a name to both indicate the type of signature (CPU vs GPU) and to throw errors early if the wrong function is registered.

### Passing in custom descriptors
Foreign functions may often need more than just “runtime information” like the values of input arrays. “Static information” that is provided in JAX at tracing/lowering/compile time also needs to be passed into the foreign function. XLA (currently) offers two separate mechanisms for providing this static information to custom calls.

On CPU, there is no official mechanism but this information can often be provided by passing a pointer value as an argument to the custom call, which points to a heap allocated object. Inside of the custom call, the pointer can be dereferenced to get the object and access its information.

On GPU, the custom call API offers opaque, a string that will be passed to the custom call. This requires that the information to be passed to the custom call needs to be serializable. Note that we can also “sneak” pointer values in the opaque string, allowing us to pass heap allocated objects as well.

From the user perspective, these details are unnecessary and can be handled internally by JAX. The user should have a single API for passing this static information into a custom call.

Suppose the user wants to pass a struct Info into their foreign function. In order to do so, the Info struct (or pointers to it) need to be available to Python so the JAX can construct MHLO that passes it back into the custom call.

#### Exposing static information to Python via a pointer (pass by reference)

```cpp
#include <pybind11/pybind11.h>
#include "jax_ffi.h"

struct Info {
float n;
};

void user_function(JaxFFI_API* api, JaxFFI_Status* status, void* static_info, void** outputs, void** inputs) {
Info info = (Info*) descriptor;
...
}

PYBIND11_MODULE(user_function_lib, m) {
m.def("get_function", []() {
return py::capsule(reinterpret_cast<void*>(user_function), JAX_FFI_CALL_CPU);
});
m.def("make_info", [](float n) {
Info* info = new Info();
info->n = n;
return py::capsule(reinterpret_cast<void*>(info), [](void* ptr) {
delete reinterpret_cast<Info*>(ptr);
});
});
}
```

This approach wraps a heap allocated object in a capsule, and destroys the object when the capsule object is destroyed. This hands the ownership of the object to Python. JAX will then take ownership of the object and give it to the executable, like it does with other capsule objects. This is how JAX handles host callbacks.

#### Exposing static information to Python via serialization (pass by value)

```cpp
m.def("make_info", [](float n) {
Info info;
info.n = n;
return pybind11::bytes(reinterpret_cast<const char*>(&info), sizeof(Info));
});
```


This approach serializes the struct as a string, then returns it to Python as a bytes object. Since we’re not doing any heap allocation, we don’t need to worry about ownership and don’t require JAX to keep a heap object alive. This is how JAX handles custom calls to external libraries like cublas, lapack, and ducc.

JAX should handle both cases (pass by reference and value) and pass the appropriate pointer back into the user foreign function.

### Handling foreign functions and descriptors in Python

We’ve shown how users expose foreign functions to Python. Now we’ll show how users register and use these functions with JAX.

#### Registering FFI calls in JAX
First we introduce a new JAX module, `jax.ffi`. `jax.ffi` will expose a `jax.ffi.FFICall` object.

```python
Platform = FunctionPointer = Any

class FFICall:
name: str
_registry: Dict[Platform, FunctionPointer]

def register_function(self, function_ptr, *, platform):
...

def __call__(self, *args, **kwargs):
...
```

We can construct `FFICall`s with a string name that uniquely identifies them (we should error if the same name is used twice).

```python
import jax.ffi

user_function = jax.ffi.FFICall("user_function")
```

We allow users to register platform-specific implementations for the FFI call.
```python
import user_function_lib # the Python extension
user_function.register_function(user_function_lib.get_function(), platform="cpu")
```

This allows users to write a CPU version of their code and a GPU version as well.
#### Calling foreign functions from JAX

`FFICall` objects have a `__call__` method that invokes a JAX primitive, `jax_ffi_call`, that has already registered transformation rules.

```python
@jax.jit
def f(...):
... = user_function(..., return_shape_dtype=...)
```

The user needs to provide a `return_shape_dtype` information since that can’t be inferred by JAX and JAX requires statically known shapes and dtypes for all values.

To pass in a descriptor as well, users can construct static information and pass it into the `user_function` via a reserved keyword argument `static_info`.

```python
@jax.jit
def f(...):
static_info = user_function_lib.make_info(4.)
... = user_function(..., static_info=static_info, return_shape_dtype=...)
```

#### JAX custom call wrapper

When the user eventually calls the `FFICall`, we emit a specific MHLO custom call (`jax_ffi_call`) during lowering. This custom call is already registered and is passed both the function pointer capsule (registered earlier) and the static info (passed into the primitive). It then prepares the `api` and `status` and calls the function pointer along with `api` and `status` with the input/output pointers.

```cpp
struct Descriptor {
void* function_ptr;
void* user_static_info;
};

extern "C" void JaxFFICallWrapper(void* output, void** inputs,
XlaCustomCallStatus* status) {
auto descriptor = reinterpret_cast<Descriptor*>(*static_cast<uintptr_t*>(inputs[0]));
inputs += 1;
JaxFFIStatus jax_ffi_status;
auto function_ptr = reinterpret_cast<void (*) (JaxFFI_API*, JaxFFIStatus*,
void*, void**, void**)>(descriptor->function_ptr);
function_ptr(JaxFFI_API_Table, &jax_ffi_status,
descriptor->user_static_info,
inputs, reinterpret_cast<void**>(output));
if (jax_ffi_status.message) {
// Handle error!
}
}
```


On CPU (for now), the `function_ptr` and `user_static_info` are passed by reference as the first argument to the custom call. On GPU, they can be passed by reference via the opaque string.

Note that XLA custom calls support custom layouts for operands and results. Here we’ll generate MHLO that uses default layouts, which technically limits what users can express.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you do not try to cover this?
Or this is planned for a next version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our thinking is:

  • We can't support this easily without either implementing our own wrapper around XLA's layout API or exposing details of XLA itself
  • Users can work around by either changing layouts manually or forgoing the FFI API and using the Custom Call API directly

In the long run, we might want to support custom layouts and that can be in a follow up version.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the other hand, we could just provide an API to specify custom major-to-minor orders in Python (like [0, 1] or [1, 0]) and that would probably be sufficient.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. Asking to not use the FFI just for that looks a very high cost. So allowing this would be great.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Peter pointed out that jaxlib uses custom layouts frequently so we should support this. Thankfully it only needs to exist in Python and doesn't require changing the C API.


### Handling JAX transformations

Unlike in the dfm guide, users are not constructing JAX primitives and therefore don’t have the opportunity to register transformation rules for those primitives. Do we want to expose them and if so, how?

For automatic differentiation, users have the option of wrapping their `FFICall` with a `jax.custom_jvp` or `jax.custom_vjp`. Alternatively we could expose additional methods on `FFICall` that do something similar. The `jax.custom_*` (`custom_vmap`, `custom_transpose`, etc.) API, in principle, could also handle any custom behavior users want from FFI calls. However, this API has not been fully built out yet.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about multi-GPU and sharding?
I think this case should be supported as we know this is a current issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an important use case but can arguably be solved orthogonally via the WIP custom partitioning API. We're still working on its design but presumably we can surface a jax.custom_sharding or something like it that will work with FFI calls.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be great.


For now, we propose not committing to a specific transformation API for custom calls and wait to see how the `custom_*` solution plays out. If users want very specific transformation behavior, they can rely on the (internal) primitive API, i.e. the status quo. Custom transformation behavior is orthogonal to the problem of enabling foreign function calls and fits into the larger discussion of how to expose custom primitives to users.

We could also consider some default transformation rules under vmap, for example. If the user promises their function is pure, we can adopt a strategy like `pure_callback` where we can sequentially map over the batch dimensions. If the function is not pure, we disable most transformations.

### Error handling

In the `JaxFFICallWrapper` above, we don’t explicitly say how we handle errors. Although we expose an API like that in XLA’s custom call, we don’t have to use XlaCustomCallStatusSetFailure, which has specific operational semantics. Instead, we can hook into the extensible error API described in the error handling JEP. Creating this layer of indirection allows us to have functional error handling in custom calls as well.

There are also a few different ways custom calls often fail. In `jaxlib`, custom calls will call `XlaCustomCallStatusSetFailure` usually when there is an unrecoverable failure (CUDA errors, OOMs, etc.). Arguably we shouldn’t handle these errors in JAX itself. Other sorts of errors, for example numerical errors in linear algebra routines, could be surfaced in JAX via the unified error handling API. We should consider having an extra bit (e.g. recoverable) in `JaxFFIStatusSetFailure` that distinguishes between these two types of errors.

## Implementation Plan

We provide a prototype of the proposed API in [this PR](https://github.com/google/jax/pull/12396).
1 change: 1 addition & 0 deletions docs/jep/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Then create a pull request that adds a file named
10657: Sequencing side-effects in JAX <10657-sequencing-effects>
11830: `jax.remat` / `jax.checkpoint` new implementation <11830-new-remat-checkpoint>
12049: Type Annotation Roadmap for JAX <12049-type-annotations>
12535: Foreign Function Interface <12535-ffi>


Several early JEPs were converted in hindsight from other documentation,
Expand Down