Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for FFI calls with side effects via ffi_call #23982

Merged
merged 1 commit into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
Expand Down Expand Up @@ -197,6 +198,7 @@ def ffi_call(
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*args: ArrayLike,
vectorized: bool = False,
has_side_effect: bool = False,
**kwargs: Any,
) -> Array | list[Array]:
"""Call a foreign function interface (FFI) target.
Expand All @@ -222,8 +224,11 @@ def ffi_call(
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the callback function can operate in
a vectorized manner, as described above.
vectorized: boolean specifying whether the FFI call can operate in a
vectorized manner, as described above.
has_side_effect: boolean specifying whether the custom call has side
effects. When ``True``, the FFI call will be executed even when the
outputs are not used.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.

Expand All @@ -242,6 +247,7 @@ def ffi_call(
result_avals=result_avals,
vectorized=vectorized,
target_name=target_name,
has_side_effect=has_side_effect,
**kwargs,
)
if multiple_results:
Expand All @@ -250,15 +256,26 @@ def ffi_call(
return results[0]


class FfiEffect(effects.Effect):
def __str__(self):
return "FFI"

_FfiEffect = FfiEffect()
effects.lowerable_effects.add_type(FfiEffect)
effects.control_flow_allowed_effects.add_type(FfiEffect)


def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
has_side_effect: bool,
**kwargs: Any,
):
del avals_in, target_name, vectorized, kwargs
return result_avals
effects = {_FfiEffect} if has_side_effect else core.no_effects
return result_avals, effects


def ffi_call_jvp(*args, target_name, **kwargs):
Expand All @@ -281,16 +298,18 @@ def ffi_call_lowering(
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
has_side_effect: bool,
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(target_name)(ctx, *operands, **kwargs)
rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
return rule(ctx, *operands, **kwargs)


ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))
ffi_call_p.def_abstract_eval(ffi_call_abstract_eval)
ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
Expand Down
20 changes: 20 additions & 0 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import unittest

import numpy as np
from absl.testing import absltest
Expand Down Expand Up @@ -178,6 +179,25 @@ def fun():
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))

def testEffectsHlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jex.ffi.ffi_call(target_name, (), has_side_effect=True)
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
Expand Down