Skip to content

Commit

Permalink
[pallas] Clean up forward-compatibility conditionals in Pallas lowering
Browse files Browse the repository at this point in the history
In cl/657184114 (July 29th) I have made some changes in error reporting for invalid block shapes, but have left behind some conditionals to ensure forward compatibility. We are now out of the forward compatibility windows, and we clean up those conditionals.

PiperOrigin-RevId: 674603915
  • Loading branch information
gnecula authored and Google-ML-Automation committed Sep 14, 2024
1 parent 0daca46 commit ee6f098
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 52 deletions.
52 changes: 12 additions & 40 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
Expand Down Expand Up @@ -448,20 +447,10 @@ def err_details():
f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in "
f"memory space {bm.block_aval.memory_space}."
"\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec")
if lowering_context.is_forward_compat() or jaxlib_version < (0, 4, 32):
# TODO(b/356116061): Remove the old rank condition
if rank < 2:
raise ValueError(
"The Pallas TPU lowering currently supports only blocks of "
"rank >= 2 for blocks, except those in the SMEM memory space "
"having the same block shape as the array shape and a "
"trivial index_map (returning all 0s). " + err_details())
else:
if rank < 1:
raise ValueError(
"The Pallas TPU lowering currently supports only blocks of "
"rank >= 1. " + err_details())

if rank < 1:
raise ValueError(
"The Pallas TPU lowering currently supports only blocks of "
"rank >= 1. " + err_details())

if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and
not bm.has_trivial_window()):
Expand All @@ -476,34 +465,17 @@ def err_details():
bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2]
else:
bs1, as1 = 1, 1
if lowering_context.is_forward_compat():
# TODO(b/356116061): Remove the old divisibility condition
# With shape polymorphism block_shape is static, but the array shape may
# be symbolic. Write the divisibility comparisons to defer inequality
# comparisons on dimensions as much as possible.

if rank >= 2:
evenly_divisible = (
(bs0 % 128 == 0 or (bs0 == as0 and as0 < 128)) and
(bs1 % 8 == 0 or (bs1 == as1 and as1 < 8))
(bs0 == as0 or bs0 % 128 == 0) and
(bs1 == as1 or bs1 % 8 == 0)
)
if not evenly_divisible:
raise ValueError(
"The Pallas TPU lowering currently requires that the last two "
"dimensions of your block shape are divisible by 8 and 128 "
"respectively, if the respective dimensions of the overall array "
"are larger than the respective factors. If array dimensions are "
"smaller, the block should span the full array dimension. "
+ err_details())
else:
if rank >= 2:
evenly_divisible = (
(bs0 == as0 or bs0 % 128 == 0) and
(bs1 == as1 or bs1 % 8 == 0)
)
else:
assert rank == 1
# TODO(necula): test this for bool. What should it do?
tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype))
evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)
assert rank == 1
# TODO(necula): test this for bool. What should it do?
tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype))
evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0)

if not evenly_divisible:
raise ValueError(
Expand Down
16 changes: 4 additions & 12 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax.control_flow.for_loop import for_loop
from jax._src.lib import version as jaxlib_version
from jax._src.pallas import core as pallas_core
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
from jax.experimental import pallas as pl
Expand Down Expand Up @@ -371,17 +370,10 @@ def copy_kernel(x_ref, o_ref):

test_context = contextlib.nullcontext()
if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
if jaxlib_version < (0, 4, 32):
# TODO(b/356116061): Remove the old rank condition
if rank < 2:
test_context = self.assertRaisesRegex(
ValueError,
"TPU lowering currently supports only blocks of rank >= 2")
else:
if rank < 1:
test_context = self.assertRaisesRegex(
ValueError,
"TPU lowering currently supports only blocks of rank >= 1")
if rank < 1:
test_context = self.assertRaisesRegex(
ValueError,
"TPU lowering currently supports only blocks of rank >= 1")

if rank >= 1:
bs0, as0 = block_shape[-1], shape[-1]
Expand Down

0 comments on commit ee6f098

Please sign in to comment.