[Inductor] Handle device property warp_size is None but used on XPU.
etaf authored and pytorchmergebot committed Sep 30, 2024
1 parent 6931c16 commit 0a26851
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_triton_heuristics.py
@@ -146,7 +146,7 @@ def pre_hook(kwargs):
         autotuner = CachingAutotuner(**args)
 
     def test_autotune_hints_to_configs(self):
-        device_props = DeviceProperties.create(torch.device("cuda"))
+        device_props = DeviceProperties.create(torch.device(GPU_TYPE))
         device_props = device_props._replace(warp_size=8)
 
         hints = {AutotuneHint.ONE_ELEMENT_PER_THREAD}
14 changes: 9 additions & 5 deletions torch/_inductor/runtime/hints.py
@@ -117,19 +117,23 @@ def create(cls, device):
             device_type = "hip"
 
         device_interface = get_interface_for_device(device)
-        if device_type in ["cuda", "hip"]:
+        if device_type in ["cuda", "hip", "xpu"]:
             props = device_interface.get_device_properties(device)
             return cls(
                 type=device_type,
                 index=device.index,
                 cc=device_interface.get_compute_capability(device),
-                major=props.major,
+                major=props.major if hasattr(props, "major") else None,
                 regs_per_multiprocessor=props.regs_per_multiprocessor
                 if hasattr(props, "regs_per_multiprocessor")
                 else None,
-                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
-                multi_processor_count=props.multi_processor_count,
-                warp_size=props.warp_size,
+                max_threads_per_multi_processor=props.max_threads_per_multi_processor
+                if hasattr(props, "max_threads_per_multi_processor")
+                else None,
+                multi_processor_count=props.multi_processor_count
+                if hasattr(props, "multi_processor_count")
+                else None,
+                warp_size=props.warp_size if hasattr(props, "warp_size") else 32,
             )
         return cls(
             type=device_type,
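The new hasattr guards mean DeviceProperties.create no longer assumes every backend reports every CUDA-style property: missing fields become None, and a missing warp_size falls back to 32. A minimal standalone sketch of that fallback pattern, using a hypothetical stand-in for the device-properties object (the SimpleNamespace props and its fields are illustrative, not the real torch API):

from types import SimpleNamespace

# Hypothetical stand-in for a device-properties object that, like XPU at the
# time of this commit, does not expose warp_size or major.
props = SimpleNamespace(multi_processor_count=64)

# The hasattr-guarded reads mirrored from DeviceProperties.create:
major = props.major if hasattr(props, "major") else None
warp_size = props.warp_size if hasattr(props, "warp_size") else 32

print(major)      # None: attribute missing, default used instead of raising
print(warp_size)  # 32: safe default rather than an AttributeError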
4 changes: 3 additions & 1 deletion torch/_inductor/runtime/triton_heuristics.py
@@ -119,7 +119,9 @@ def autotune_hints_to_configs(
                 triton_config(
                     size_hints,
                     *xyz,
-                    num_elements_per_warp=device_props.warp_size,
+                    num_elements_per_warp=device_props.warp_size
+                    if device_props.warp_size
+                    else 32,
                 )
             )

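Because DeviceProperties.warp_size can now be None for backends that do not report it, the autotune hint falls back to 32 elements per warp. A small sketch of that guard, with an illustrative helper name (elements_per_warp is not part of the PyTorch API):

def elements_per_warp(warp_size):
    # Mirrors the fallback in autotune_hints_to_configs: use the reported
    # warp size when present, otherwise assume the common value of 32.
    return warp_size if warp_size else 32

assert elements_per_warp(64) == 64    # e.g. a backend reporting 64-lane warps
assert elements_per_warp(None) == 32  # e.g. XPU, where warp_size was None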
