Skip to content

Commit

Permalink
Remove the `jax_require_devices_during_lowering flag since it was tem…
Browse files Browse the repository at this point in the history
…porary. Added the semi-breaking change to Changelog.md.

PiperOrigin-RevId: 590684939
  • Loading branch information
yashk2810 authored and jax authors committed Dec 13, 2023
1 parent 0281596 commit 3c07c10
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 20 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.22

* Changes
* JAX lowering to StableHLO does not depend on physical devices anymore.
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
This is needed because custom_partitioning and JAX callbacks need physical
devices to create `Sharding`s during lowering.
This is a temporary state until we can create `Sharding`s without physical
devices.

* Deprecations
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
Explicit buffers have been replaced by the more flexible array sharding interface,
Expand Down
14 changes: 2 additions & 12 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,6 @@ class WeakRefList(list):
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec

# TODO(yashkatariya): Remove this flag when nvidia's use cases are fixed.
_JAX_REQUIRE_DEVICES_DURING_LOWERING = config.DEFINE_bool(
"jax_require_devices_during_lowering",
True,
help="Forces physical devices to be passed during lowering to stablehlo.")

### util

def identity(x): return x
Expand Down Expand Up @@ -1977,17 +1971,13 @@ def lower_sharding_computation(
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
materialized_da = (
tuple(da_object)
if prim_requires_devices or _JAX_REQUIRE_DEVICES_DURING_LOWERING.value
else None)

(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
materialized_da, donated_invars, name_stack, all_default_mem_kind,
lowering_parameters=lowering_parameters)
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters)

# backend and device_assignment is passed through to MeshExecutable because
# if keep_unused=False and all in_shardings are pruned, then there is no way
Expand Down
11 changes: 3 additions & 8 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3859,14 +3859,9 @@ def g(a):
b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
f(b) # lowering cache *hit*

prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
try:
jax.config.update('jax_require_devices_during_lowering', False)
with jtu.count_jit_and_pmap_compiles() as count:
g(np.arange(8))
self.assertEqual(count[0], 1)
finally:
jax.config.update('jax_require_devices_during_lowering', prev_value)
with jtu.count_jit_and_pmap_compiles() as count:
g(np.arange(8))
self.assertEqual(count[0], 1)

def test_lowering_cache_miss_different_devices_and_sharding(self):
if jax.device_count() < 4:
Expand Down

0 comments on commit 3c07c10

Please sign in to comment.