Skip to content

Commit

Permalink
[mosaic:gpu] Rename once -> single_thread.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 640475714
  • Loading branch information
chr1sj0nes authored and jax authors committed Jun 5, 2024
1 parent 485fad5 commit 557cae6
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/mosaic/gpu/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
memref_transpose,
memref_unfold,
memref_unsqueeze,
once,
tile_shape,
single_thread,
thread_idx,
tile_shape,
warp_idx,
warpgroup_idx,
)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/examples/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def kernel(
)

def kv_copy_init(slot, kv_seq_base):
with once():
with single_thread():
txcount = c(2 * blocks.kv * head_dim * bytewidth(f16))
nvgpu.mbarrier_arrive_expect_tx(barriers.value, txcount, slot)
k_tr = (
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/examples/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def fetch(slot, ki):
common_copy_args = dict(
swizzle=128, barrier=barrier, arrive=False, uniform=False,
)
with once():
with single_thread():
nvgpu.mbarrier_arrive_expect_tx(barrier_group.value, txcount, slot)
ctx.async_copy(
src_ref=a_device,
Expand Down
11 changes: 4 additions & 7 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def debug_print(fmt, *args, uniform=True):
raise NotImplementedError(arg.type)
type_formats.append(ty_format)
new_args.append(arg)
ctx = once if uniform else contextlib.nullcontext
ctx = single_thread if uniform else contextlib.nullcontext
with ctx():
gpu.printf(fmt.format(*type_formats) + "\n", new_args)

Expand Down Expand Up @@ -219,11 +219,8 @@ def warpgroup_idx(sync=True):


@contextlib.contextmanager
def once():
"""Runs the context only from a single thread from the first warp.
The block is assumed to have a size of 1 in both y and z dimensions.
"""
def single_thread():
"""Runs the context only from a single thread."""
global _ONCE_REGION_ACTIVE

if _ONCE_REGION_ACTIVE:
Expand Down Expand Up @@ -502,7 +499,7 @@ def __init__(self, num_barriers: int, arrival_count: int = 1):
i32 = ir.IntegerType.get_signless(32)
self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
memref.store(c(0, i32), self.phases, [])
with once():
with single_thread():
for i in range(num_barriers):
nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))
gpu.barrier()
Expand Down

0 comments on commit 557cae6

Please sign in to comment.