[inductor] Reduce block sizes when using Triton CPU backend (pytorch#136612)

This greatly reduces compile time: TorchBench models that previously compiled 50-100x slower than with the cpp backend now compile only ~20x slower. More work needs to be done on the Triton side, but smaller block sizes will still be helpful.

Pull Request resolved: pytorch#136612
Approved by: https://github.com/desertfire
ghstack dependencies: pytorch#135342
int3 authored and pytorchmergebot committed Oct 3, 2024
1 parent 4513fb5 commit b3953ff
Showing 4 changed files with 84 additions and 18 deletions.
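In outline, the change threads two new knobs through Inductor's GEMM/conv config selection: a scale factor that shrinks every candidate (BLOCK_M, BLOCK_N, BLOCK_K), and an exclude predicate that drops configs the Triton CPU backend is slow to compile. A minimal standalone sketch of that mechanism (the function shrink_configs and the sample thresholds below are illustrative, not the actual Inductor API):

```python
from typing import Callable, Iterable, List, Tuple

Config = Tuple[int, int, int]  # (BLOCK_M, BLOCK_N, BLOCK_K)


def shrink_configs(
    configs: Iterable[Config],
    scale: float = 1.0,
    exclude: Callable[[int, int, int], bool] = lambda m, n, k: False,
    min_block: int = 16,
) -> List[Config]:
    """Scale candidate block sizes down and drop any the predicate rejects."""
    out = []
    for bm, bn, bk in configs:
        bm, bn, bk = (max(int(b * scale), min_block) for b in (bm, bn, bk))
        if exclude(bm, bn, bk):
            continue  # still too large for the Triton CPU backend to compile quickly
        out.append((bm, bn, bk))
    return out


# On CPU, halve the candidate block sizes and reject anything still "large".
cpu_configs = shrink_configs(
    [(64, 64, 32), (128, 128, 64), (256, 256, 64)],
    scale=0.5,
    exclude=lambda m, n, k: m * n > 2**13,
)
print(cpu_configs)  # [(32, 32, 16), (64, 64, 32)]
```

Scaling and filtering the existing config tables, rather than maintaining a separate CPU table, leaves the GPU tuning space untouched while shrinking the kernels the CPU backend has to compile.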
17 changes: 15 additions & 2 deletions torch/_inductor/kernel/bmm.py
@@ -28,6 +28,19 @@ def bmm_grid(b, m, n, meta):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1)


def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 128 or n > 128 or k > 128:
return True
return m * n > 2**12


def bmm_configs(m, n, k, *, device_type):
if device_type == "cpu":
return mm_configs(m, n, k, scale=0.5, exclude=_is_large_block_for_cpu)
return mm_configs(m, n, k)


bmm_template = TritonTemplate(
name="bmm",
grid=bmm_grid,
@@ -147,7 +160,7 @@ def may_require_contiguous(t, meta_t):
# options to tune from
choices = [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else []
if use_triton_template(layout):
for config in mm_configs(m, n, k):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
bmm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -179,7 +192,7 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
else []
)
if use_triton_template(layout):
for config in mm_configs(m, n, k):
for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)):
bmm_template.maybe_append_choice(
choices,
input_nodes=(inp, mat1, mat2),
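To make the bmm thresholds concrete, a small standalone check (the predicate body is copied from the hunk above; the sample shapes are arbitrary):

```python
def _is_large_block_for_cpu(m, n, k):
    # Thresholds are experimentally determined to reduce Triton CPU compile times
    if m > 128 or n > 128 or k > 128:
        return True
    return m * n > 2**12


# Any block dimension above 128, or an m*n tile above 2**12 = 4096, is skipped on CPU.
for shape in [(64, 64, 32), (64, 64, 256), (128, 64, 32), (64, 32, 64)]:
    print(shape, "excluded" if _is_large_block_for_cpu(*shape) else "kept")
# (64, 64, 32)  kept      (m * n == 4096, not strictly greater)
# (64, 64, 256) excluded  (k > 128)
# (128, 64, 32) excluded  (m * n == 8192 > 4096)
# (64, 32, 64)  kept
```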
26 changes: 21 additions & 5 deletions torch/_inductor/kernel/conv.py
@@ -2,7 +2,6 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import functools
import logging
from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict

@@ -83,10 +82,26 @@ def conv3d_grid(n, c, d, h, w, meta):
(config[0], config[1], config[2], 1, config[4]) for config in platform_configs
)

conv_configs = functools.partial(
filtered_configs,
configs=platform_configs,
)

def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
if m > 256 or n > 256 or k > 256:
return True
return m * n * k > 2**17


def conv_configs(m, n, k, *, device_type, **kwargs):
if device_type == "cpu":
return filtered_configs(
m,
n,
k,
configs=platform_configs,
scale=0.5,
exclude=_is_large_block_for_cpu,
)
return filtered_configs(m, n, k, configs=platform_configs)


LOOP_BODY_2D = """
idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H
@@ -599,6 +614,7 @@ def channels_last_conv():
sympy_product([x.get_size()[0], *x.get_size()[2:]]),
out_chan,
in_chan,
device_type=ir.get_device_type(x),
):
if ndim == 2:
conv2d_template.maybe_append_choice(
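Similarly for conv, where the budget is on the full m*n*k volume rather than the m*n tile (predicate copied from the hunk above; shapes are arbitrary):

```python
def _is_large_block_for_cpu(m, n, k):
    # Thresholds are experimentally determined to reduce Triton CPU compile times
    if m > 256 or n > 256 or k > 256:
        return True
    return m * n * k > 2**17


# A hard cap of 256 on any dimension, plus a 2**17 == 131072 cap on m * n * k.
for shape in [(64, 64, 16), (256, 256, 4), (128, 128, 16), (512, 16, 16)]:
    print(shape, "excluded" if _is_large_block_for_cpu(*shape) else "kept")
# (64, 64, 16)   kept      (volume 65536)
# (256, 256, 4)  excluded  (volume 262144)
# (128, 128, 16) excluded  (volume 262144)
# (512, 16, 16)  excluded  (512 > 256)
```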
40 changes: 33 additions & 7 deletions torch/_inductor/kernel/mm.py
@@ -16,7 +16,7 @@
from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate
from torch._inductor.virtualized import V

from .. import config as inductor_config
from .. import config as inductor_config, ir
from ..codegen.common import BackendFeature
from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate
from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
@@ -148,6 +148,20 @@ def _is_int8_mat(mat):
return mat.get_dtype() in (torch.int8, torch.uint8)


def _is_large_block_for_cpu(m, n, k):
# Thresholds are experimentally determined to reduce Triton CPU compile times
return m * n > 2**13


def mm_config_kwargs(device):
if device == "cpu":
return {
"scale": 0.5,
"exclude": _is_large_block_for_cpu,
}
return {}


def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
"""
Giving torch.addmm a 1D tensor calls a different (faster) cublasLt
@@ -179,7 +193,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
)
static_shape, is_nonzero = _is_static_problem([mat1, mat2], layout)
if is_nonzero and use_triton_template(layout):
for config in mm_configs(m, n, k):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -210,7 +224,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
if use_aten_gemm_kernels():
always_included.append("extern_mm")
num_choices_before_extra_configs = len(choices)
for config in extra_mm_configs(m, n, k):
for config in extra_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -308,7 +324,9 @@ def tuned_int_mm(mat1, mat2, *, layout=None):
choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
)
if is_nonzero and use_triton_template(layout, enable_int32=True):
for config in int8_mm_configs(m, n, k):
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -387,7 +405,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
)

if is_nonzero and use_triton_template(layout):
for config in mm_configs(m, n, k):
for config in mm_configs(m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))):
mm_template.maybe_append_choice(
choices,
input_nodes=(inp_expanded, mat1, mat2),
@@ -721,7 +739,13 @@ def tuned_mixed_mm(mat1, mat2, mat2_dtype):
choices.append(fallback)

has_int8_tensor = _is_int8_mat(mat1) or _is_int8_mat(mat2)
for config in mixed_mm_configs(m, n, k, has_int8_tensor=has_int8_tensor):
for config in mixed_mm_configs(
m,
n,
k,
has_int8_tensor=has_int8_tensor,
**mm_config_kwargs(ir.get_device_type(mat1)),
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
@@ -778,7 +802,9 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
mat1, mat2, mat3, layout=layout, out_dtype=out_dtype
)
choices: List[Dict[Any, Any]] = []
for config in int8_mm_configs(m, n, k):
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2, mat3),
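All of the mm.py call sites follow the same pattern: build the device-dependent keyword arguments once with mm_config_kwargs and splat them into whichever config generator is in play. A rough sketch of that pattern (the generator below is a stub standing in for mm_configs / int8_mm_configs / mixed_mm_configs, which in the real code forward scale and exclude to mm_common.filtered_configs):

```python
def _is_large_block_for_cpu(m, n, k):
    # Thresholds are experimentally determined to reduce Triton CPU compile times
    return m * n > 2**13


def mm_config_kwargs(device):
    if device == "cpu":
        return {"scale": 0.5, "exclude": _is_large_block_for_cpu}
    return {}


def some_mm_configs(m, n, k, *, scale=1, exclude=lambda m, n, k: False):
    # Stub config generator; the real ones forward scale/exclude to filtered_configs.
    for bm, bn, bk, stages, warps in [(64, 64, 32, 2, 4), (256, 256, 64, 3, 8)]:
        bm, bn, bk = (max(int(b * scale), 16) for b in (bm, bn, bk))
        if not exclude(bm, bn, bk):
            yield (bm, bn, bk, stages, warps)


# Call sites stay device-agnostic; the CPU-only shrinking hides in the kwargs.
for device in ("cuda", "cpu"):
    print(device, list(some_mm_configs(1024, 1024, 1024, **mm_config_kwargs(device))))
# cuda [(64, 64, 32, 2, 4), (256, 256, 64, 3, 8)]
# cpu  [(32, 32, 16, 2, 4)]
```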
19 changes: 15 additions & 4 deletions torch/_inductor/kernel/mm_common.py
@@ -30,8 +30,15 @@ def filtered_configs(
k: int,
configs: Sequence[Tuple[int, int, int, int, int]],
has_int8_tensor=False,
scale=1,
exclude=lambda m, n, k: False,
):
"""Heuristic to shrink configs when they are bigger than the input size"""
"""
Heuristic to shrink configs when they are bigger than the input size
:param scale: scale factor applied to the config values
:param exclude: whether a given config should be excluded
"""

min_block_size = 16
# block_k=16 seems to be causing issues
@@ -64,9 +71,13 @@
used = set()
for block_m, block_n, block_k, num_stages, num_warps in configs:
# shrink configs for small sizes
block_m = max(min(block_m, m), min_block_size)
block_n = max(min(block_n, n), min_block_size)
block_k = max(min(block_k, k), min_block_size_k)
block_m = max(min(int(block_m * scale), m), min_block_size)
block_n = max(min(int(block_n * scale), n), min_block_size)
block_k = max(min(int(block_k * scale), k), min_block_size_k)

if exclude(block_m, block_n, block_k):
continue

# each warp computes 16x16 tile = 256
num_warps = min(num_warps, block_m * block_n // 256)
if torch.version.hip:
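Putting the mm_common.py changes together, the selection loop now looks roughly like this simplified standalone sketch (it yields plain tuples and ignores the int8 minimum-block and ROCm special cases handled in the real filtered_configs):

```python
def filtered_configs_sketch(m, n, k, configs, scale=1, exclude=lambda m, n, k: False):
    """configs: iterable of (block_m, block_n, block_k, num_stages, num_warps)."""
    min_block_size = 16
    min_block_size_k = 16  # the real code bumps this to 32 for int8 inputs
    used = set()
    for block_m, block_n, block_k, num_stages, num_warps in configs:
        # apply the CPU scale factor, then shrink configs for small problem sizes
        block_m = max(min(int(block_m * scale), m), min_block_size)
        block_n = max(min(int(block_n * scale), n), min_block_size)
        block_k = max(min(int(block_k * scale), k), min_block_size_k)
        if exclude(block_m, block_n, block_k):
            continue
        # each warp computes a 16x16 tile = 256 elements
        num_warps = min(num_warps, block_m * block_n // 256)
        cfg = (block_m, block_n, block_k, num_stages, num_warps)
        if cfg not in used:
            used.add(cfg)
            yield cfg


base = [(64, 64, 32, 2, 4), (128, 128, 32, 3, 8), (256, 128, 64, 4, 8)]
print(list(filtered_configs_sketch(512, 512, 512, base, scale=0.5,
                                   exclude=lambda m, n, k: m * n > 2**12)))
# [(32, 32, 16, 2, 4), (64, 64, 16, 3, 8)]
```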