Skip to content

Commit

Permalink
Merge pull request #23982 from dfm:ffi-call-effects
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679780356
  • Loading branch information
Google-ML-Automation committed Sep 28, 2024
2 parents 061f435 + d80a89d commit 15024ba
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
29 changes: 24 additions & 5 deletions jax/_src/extend/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
Expand Down Expand Up @@ -197,6 +198,7 @@ def ffi_call(
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*args: ArrayLike,
vectorized: bool = False,
has_side_effect: bool = False,
**kwargs: Any,
) -> Array | list[Array]:
"""Call a foreign function interface (FFI) target.
Expand All @@ -222,8 +224,11 @@ def ffi_call(
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the callback function can operate in
a vectorized manner, as described above.
vectorized: boolean specifying whether the FFI call can operate in a
vectorized manner, as described above.
has_side_effect: boolean specifying whether the custom call has side
effects. When ``True``, the FFI call will be executed even when the
outputs are not used.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.
Expand All @@ -242,6 +247,7 @@ def ffi_call(
result_avals=result_avals,
vectorized=vectorized,
target_name=target_name,
has_side_effect=has_side_effect,
**kwargs,
)
if multiple_results:
Expand All @@ -250,15 +256,26 @@ def ffi_call(
return results[0]


class FfiEffect(effects.Effect):
def __str__(self):
return "FFI"

_FfiEffect = FfiEffect()
effects.lowerable_effects.add_type(FfiEffect)
effects.control_flow_allowed_effects.add_type(FfiEffect)


def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
has_side_effect: bool,
**kwargs: Any,
):
del avals_in, target_name, vectorized, kwargs
return result_avals
effects = {_FfiEffect} if has_side_effect else core.no_effects
return result_avals, effects


def ffi_call_jvp(*args, target_name, **kwargs):
Expand All @@ -281,16 +298,18 @@ def ffi_call_lowering(
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
has_side_effect: bool,
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
return ffi_lowering(target_name)(ctx, *operands, **kwargs)
rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
return rule(ctx, *operands, **kwargs)


ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p))
ffi_call_p.def_abstract_eval(ffi_call_abstract_eval)
ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
Expand Down
20 changes: 20 additions & 0 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import unittest

import numpy as np
from absl.testing import absltest
Expand Down Expand Up @@ -178,6 +179,25 @@ def fun():
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))

def testEffectsHlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jex.ffi.ffi_call(target_name, (), has_side_effect=True)
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())

@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
Expand Down

0 comments on commit 15024ba

Please sign in to comment.