diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index 833ac4f615a8..c4a73840b4fd 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -22,6 +22,7 @@ from jax._src import core from jax._src import dispatch +from jax._src import effects from jax._src import util from jax._src.callback import _check_shape_dtype, callback_batching_rule from jax._src.interpreters import ad @@ -197,6 +198,7 @@ def ffi_call( result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *args: ArrayLike, vectorized: bool = False, + has_side_effect: bool = False, **kwargs: Any, ) -> Array | list[Array]: """Call a foreign function interface (FFI) target. @@ -222,8 +224,11 @@ def ffi_call( used to define the elements of ``result_shape_dtypes``. ``jax.core.abstract_token`` may be used to represent a token-typed output. *args: the arguments passed to the custom call. - vectorized: boolean specifying whether the callback function can operate in - a vectorized manner, as described above. + vectorized: boolean specifying whether the FFI call can operate in a + vectorized manner, as described above. + has_side_effect: boolean specifying whether the custom call has side + effects. When ``True``, the FFI call will be executed even when the + outputs are not used. **kwargs: keyword arguments that are passed as named attributes to the custom call using XLA's FFI interface. @@ -242,6 +247,7 @@ def ffi_call( result_avals=result_avals, vectorized=vectorized, target_name=target_name, + has_side_effect=has_side_effect, **kwargs, ) if multiple_results: @@ -250,15 +256,26 @@ def ffi_call( return results[0] +class FfiEffect(effects.Effect): + def __str__(self): + return "FFI" + +_FfiEffect = FfiEffect() +effects.lowerable_effects.add_type(FfiEffect) +effects.control_flow_allowed_effects.add_type(FfiEffect) + + def ffi_call_abstract_eval( *avals_in, result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, + has_side_effect: bool, **kwargs: Any, ): del avals_in, target_name, vectorized, kwargs - return result_avals + effects = {_FfiEffect} if has_side_effect else core.no_effects + return result_avals, effects def ffi_call_jvp(*args, target_name, **kwargs): @@ -281,16 +298,18 @@ def ffi_call_lowering( result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, + has_side_effect: bool, **kwargs: Any, ) -> Sequence[ir.Value]: del result_avals, vectorized - return ffi_lowering(target_name)(ctx, *operands, **kwargs) + rule = ffi_lowering(target_name, has_side_effect=has_side_effect) + return rule(ctx, *operands, **kwargs) ffi_call_p = core.Primitive("ffi_call") ffi_call_p.multiple_results = True ffi_call_p.def_impl(functools.partial(dispatch.apply_primitive, ffi_call_p)) -ffi_call_p.def_abstract_eval(ffi_call_abstract_eval) +ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval) ad.primitive_jvps[ffi_call_p] = ffi_call_jvp ad.primitive_transposes[ffi_call_p] = ffi_call_transpose batching.primitive_batchers[ffi_call_p] = functools.partial( diff --git a/tests/extend_test.py b/tests/extend_test.py index fff3314a7656..cacad4f6e452 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import unittest import numpy as np from absl.testing import absltest @@ -178,6 +179,25 @@ def fun(): self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) + def testEffectsHlo(self): + # The target name must exist on the current platform, but we don't actually + # need to call it with the correct syntax, because we're only checking the + # compiled HLO. + if jtu.test_device_matches(["cpu"]): + target_name = "lapack_sgetrf_ffi" + elif jtu.test_device_matches(["rocm"]): + target_name = "hipsolver_getrf_ffi" + elif jtu.test_device_matches(["cuda", "gpu"]): + target_name = "cusolver_getrf_ffi" + else: + raise unittest.SkipTest("Unsupported device") + def fun(): + jex.ffi.ffi_call(target_name, (), has_side_effect=True) + hlo = jax.jit(fun).lower() + self.assertIn(target_name, hlo.as_text()) + self.assertIn("has_side_effect = true", hlo.as_text()) + self.assertIn(target_name, hlo.compile().as_text()) + @jtu.sample_product( shape=[(1,), (4,), (5,)], dtype=(np.int32,),