diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py
index 428ff06d4a052..7c6ba90398998 100644
--- a/torch/_inductor/kernel/bmm.py
+++ b/torch/_inductor/kernel/bmm.py
@@ -28,6 +28,19 @@ def bmm_grid(b, m, n, meta):
     return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    if m > 128 or n > 128 or k > 128:
+        return True
+    return m * n > 2**12
+
+
+def bmm_configs(m, n, k, *, device_type):
+    if device_type == "cpu":
+        return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
+    return mm_configs(m, n, k)
+
+
 bmm_template = TritonTemplate(
     name="bmm",
     grid=bmm_grid,
@@ -147,7 +160,7 @@ def may_require_contiguous(t, meta_t):
     # options to tune from
     choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
     if use_triton_template(layout):
-        for config in mm_configs(m, n, k):
+        for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
             bmm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -179,7 +192,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         else []
     )
     if use_triton_template(layout):
-        for config in mm_configs(m, n, k):
+        for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
             bmm_template.maybe_append_choice(
                 choices,
                 input_nodes=(inp, mat1, mat2),
diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py
index b69143fe03015..71e3a21b005ef 100644
--- a/torch/_inductor/kernel/conv.py
+++ b/torch/_inductor/kernel/conv.py
@@ -2,7 +2,6 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations

-import functools
 import logging
 from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

@@ -83,10 +82,26 @@ def conv3d_grid(n, c, d, h, w, meta):
     (config[0], config[1], config[2], 1, config[4]) for config in platform_configs
 )

-conv_configs = functools.partial(
-    filtered_configs,
-    configs=platform_configs,
-)
+
+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    if m > 256 or n > 256 or k > 256:
+        return True
+    return m * n * k > 2**17
+
+
+def conv_configs(m, n, k, *, device_type, **kwargs):
+    if device_type == "cpu":
+        return filtered_configs(
+            m,
+            n,
+            k,
+            configs=platform_configs,
+            scale=0.5,
+            exclude=_is_large_block_for_cpu,
+        )
+    return filtered_configs(m, n, k, configs=platform_configs)
+

 LOOP_BODY_2D = """
     idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
@@ -599,6 +614,7 @@ def channels_last_conv():
            sympy_product([x.get_size()[0], *x.get_size()[2:]]),
            out_chan,
            in_chan,
+            device_type=ir.get_device_type(x),
        ):
            if ndim == 2:
                conv2d_template.maybe_append_choice(
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index ee4e759d9c14e..20b8c93ff0537 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -16,7 +16,7 @@
 from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
 from torch._inductor.virtualized import V

-from .. import config as inductor_config
+from .. import config as inductor_config, ir
 from ..codegen.common import BackendFeature
 from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
 from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
@@ -148,6 +148,20 @@ def _is_int8_mat(mat):
     return mat.get_dtype() in (torch.int8, torch.uint8)


+def _is_large_block_for_cpu(m, n, k):
+    # Thresholds are experimentally determined to reduce Triton CPU compile times
+    return m * n > 2**13
+
+
+def mm_config_kwargs(device):
+    if device == "cpu":
+        return {
+            "scale": 0.5,
+            "exclude": _is_large_block_for_cpu,
+        }
+    return {}
+
+
 def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
     """
     Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@@ -179,7 +193,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
     )
     static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
     if is_nonzero and use_triton_template(layout):
-        for config in mm_configs(m, n, k):
+        for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -210,7 +224,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
         if use_aten_gemm_kernels():
             always_included.append("extern_mm")
         num_choices_before_extra_configs = len(choices)
-        for config in extra_mm_configs(m, n, k):
+        for config in extra_mm_configs(
+            m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+        ):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -308,7 +324,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
             choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
         )
     if is_nonzero and use_triton_template(layout, enable_int32=True):
-        for config in int8_mm_configs(m, n, k):
+        for config in int8_mm_configs(
+            m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+        ):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(mat1, mat2),
@@ -387,7 +405,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         )

     if is_nonzero and use_triton_template(layout):
-        for config in mm_configs(m, n, k):
+        for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
             mm_template.maybe_append_choice(
                 choices,
                 input_nodes=(inp_expanded, mat1, mat2),
@@ -721,7 +739,13 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
         choices.append(fallback)

     has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
-    for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
+    for config in mixed_mm_configs(
+        m,
+        n,
+        k,
+        has_int8_tensor=has_int8_tensor,
+        **mm_config_kwargs(ir.get_device_type(mat1)),
+    ):
         mm_template.maybe_append_choice(
             choices,
             input_nodes=(mat1, mat2),
@@ -778,7 +802,9 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
         mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
     )
     choices: List[Dict[Any, Any]] = []
-    for config in int8_mm_configs(m, n, k):
+    for config in int8_mm_configs(
+        m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+    ):
         mm_template.maybe_append_choice(
             choices,
             input_nodes=(mat1, mat2, mat3),
diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py
index 21ba6c1e215db..58bfa3a288242 100644
--- a/torch/_inductor/kernel/mm_common.py
+++ b/torch/_inductor/kernel/mm_common.py
@@ -30,8 +30,15 @@ def filtered_configs(
     k: int,
     configs: Sequence[Tuple[int, int, int, int, int]],
     has_int8_tensor=False,
+    scale=1,
+    exclude=lambda m, n, k: False,
 ):
-    """Heuristic to shrink configs when they are bigger than the input size"""
+    """
+    Heuristic to shrink configs when they are bigger than the input size
+
+    :param scale: scale factor applied to the config values
+    :param exclude: whether a given config should be excluded
+    """

     min_block_size = 16
     # block_k=16 seems to be causing issues
@@ -64,9 +71,13 @@ def filtered_configs(
     used = set()
     for block_m, block_n, block_k, num_stages, num_warps in configs:
         # shrink configs for small sizes
-        block_m = max(min(block_m, m), min_block_size)
-        block_n = max(min(block_n, n), min_block_size)
-        block_k = max(min(block_k, k), min_block_size_k)
+        block_m = max(min(int(block_m * scale), m), min_block_size)
+        block_n = max(min(int(block_n * scale), n), min_block_size)
+        block_k = max(min(int(block_k * scale), k), min_block_size_k)
+
+        if exclude(block_m, block_n, block_k):
+            continue
+
         # each warp computes 16x16 tile = 256
         num_warps = min(num_warps, block_m * block_n // 256)
         if torch.version.hip: