[Inductor] Use parametrize to break down some unit tests (pytorch#137156)

Summary: Addresses the issue that some tests are marked as slow; see pytorch#136940 (comment)

Pull Request resolved: pytorch#137156
Approved by: https://github.com/eellison
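
For context, a minimal sketch of the pattern this PR adopts (illustrative only; the class and test names below are hypothetical, not the real Inductor tests): `parametrize` from `torch.testing._internal.common_utils` turns one test body into a separate generated test per value, so each dtype runs and is timed independently instead of inside one long loop inside a single slow test.

```python
# Minimal sketch of the parametrize pattern (illustrative, not code from this PR).
import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

test_dtypes = [torch.float32, torch.float16, torch.int32]


class MySketchTests(TestCase):
    @parametrize("dtype", test_dtypes)
    def test_add(self, dtype):
        # Each dtype becomes its own generated test (named with a dtype suffix,
        # e.g. test_add_float32), so one slow dtype no longer slows the whole test.
        x = torch.ones(4, dtype=dtype)
        self.assertTrue(torch.equal(x + x, 2 * x))


# Expands the @parametrize'd methods into concrete test cases.
instantiate_parametrized_tests(MySketchTests)

if __name__ == "__main__":
    run_tests()
```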
desertfire authored and pytorchmergebot committed Oct 3, 2024
1 parent 7631a04 commit 4513fb5
Showing 2 changed files with 81 additions and 64 deletions.
34 changes: 25 additions & 9 deletions test/inductor/test_cuda_cpp_wrapper.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import itertools
 import sys
 import unittest
 from typing import NamedTuple
@@ -72,6 +73,7 @@ def make_test_case(
     slow=False,
     func_inputs=None,
     code_string_count=None,
+    check_code=True,
 ):
     test_name = f"{name}_{device}" if device else name
     if code_string_count is None:
@@ -94,13 +96,14 @@ def fn(self):
             _, code = test_torchinductor.run_and_get_cpp_code(
                 func, *func_inputs if func_inputs else []
             )
-            self.assertEqual("CppWrapperCodeCache" in code, True)
-            self.assertTrue(
-                all(
-                    code.count(string) == code_string_count[string]
-                    for string in code_string_count
+            if check_code:
+                self.assertEqual("CppWrapperCodeCache" in code, True)
+                self.assertTrue(
+                    all(
+                        code.count(string) == code_string_count[string]
+                        for string in code_string_count
+                    )
                 )
-            )
         finally:
             tests.tearDown()
             tests.tearDownClass()
@@ -124,6 +127,7 @@ class BaseTest(NamedTuple):
     name: str
     device: str = "cuda"
     tests: InductorTestCase = test_torchinductor.GPUTests()
+    check_code: bool = True

 # Maintain two separate test lists for cuda and cpp for now
 for item in [
@@ -169,7 +173,10 @@ class BaseTest(NamedTuple):
     BaseTest("test_sum_dtype"),  # float64
     BaseTest("test_sum_int"),  # bool, int64, int8, uint8
     BaseTest("test_transpose"),  # multiple outputs, buffer clear
-    BaseTest("test_unspec_inputs"),
+    *[
+        BaseTest(f"test_unspec_inputs_{str(dtype)[6:]}")
+        for dtype in test_torchinductor.test_dtypes
+    ],
     BaseTest("test_consecutive_split_cumprod"),
     BaseTest("test_pointwise_hermite_polynomial_he"),
     BaseTest("test_pointwise_hermite_polynomial_h"),
@@ -208,7 +215,16 @@ class BaseTest(NamedTuple):
     ),
     BaseTest("test_fft_real_input"),
     BaseTest("test_fft_real_input_real_output"),
-    BaseTest("test_dtypeview"),
+    *[
+        # some dtypes may raise exception and be skipped in test_dtypeview, so set check_code to False here
+        BaseTest(
+            f"test_dtypeview_{str(dtype_x)[6:]}_{str(dtype_y)[6:]}",
+            check_code=False,
+        )
+        for dtype_x, dtype_y in itertools.product(
+            test_torchinductor.test_dtypes, test_torchinductor.test_dtypes
+        )
+    ],
     BaseTest("test_dtypeview_fusion"),
     # skip if not enough SMs
     BaseTest(
@@ -221,7 +237,7 @@
         tests=test_select_algorithm.TestSelectAlgorithm(),
     ),
 ]:
-    make_test_case(item.name, item.device, item.tests)
+    make_test_case(item.name, item.device, item.tests, check_code=item.check_code)

 from torch._inductor.utils import is_big_gpu

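A note on the naming contract between the two files (the snippet below is illustrative, not part of the commit): the list entries above are built with `str(dtype)[6:]`, which strips the `"torch."` prefix from the dtype's string form, so they line up with the per-dtype tests that `parametrize` generates in `test_torchinductor.py` below; `make_test_case` then appends the device suffix via `test_name = f"{name}_{device}"`, giving names like `test_unspec_inputs_float32_cuda`.

```python
# Illustrative check of the name suffix used above (not part of the commit):
# str(torch.float32) == "torch.float32", so [6:] drops the "torch." prefix.
import torch

for dtype in (torch.float32, torch.bfloat16, torch.int64):
    assert str(dtype).startswith("torch.")
    print(f"test_unspec_inputs_{str(dtype)[6:]}")
# -> test_unspec_inputs_float32
# -> test_unspec_inputs_bfloat16
# -> test_unspec_inputs_int64
```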
111 changes: 56 additions & 55 deletions test/inductor/test_torchinductor.py
@@ -137,6 +137,19 @@
 i64 = torch.int64
 i32 = torch.int32

+test_dtypes = [
+    torch.float32,
+    torch.float64,
+    torch.float16,
+    torch.uint8,
+    torch.int8,
+    torch.int16,
+    torch.int32,
+    torch.int64,
+]
+if SM80OrLater:
+    test_dtypes.append(torch.bfloat16)
+

 def _large_cumprod_input(shape, dim, dtype, device):
     # Construct a cumprod input which guaruntees not to overflow or underflow
@@ -9262,32 +9275,32 @@ def fn1(i0, i1):
         self.common(fn0, [torch.rand(10, 3, 10), torch.rand(3, 10, 10)])
         self.common(fn1, [torch.rand(3, 10, 10), torch.rand(3, 10, 10)])

-    @skip_if_gpu_halide  # https://github.com/halide/Halide/issues/8318
-    def test_unspec_inputs(self):
+    @parametrize(
+        "dtype",
+        test_dtypes,
+    )
+    def test_unspec_inputs(self, dtype):
         if self.device == "cpu":
             raise unittest.SkipTest("Testing mixed devices")

+        if (
+            is_halide_backend(self.device)
+            and getattr(self.device, "type", self.device) == "cuda"
+        ):
+            # https://github.com/halide/Halide/issues/8318
+            raise unittest.SkipTest("halide not supported")
+
         def fn(x, y):
             return x + y, x * y, x / y

         opt = torch._dynamo.optimize("inductor")(fn)
-        dtypes = [
-            torch.float16,
-            torch.bfloat16,
-            torch.float32,
-            torch.float64,
-            torch.int32,
-            torch.int64,
-        ]
-
-        for d in dtypes:
-            inputs = (
-                rand_strided((2, 3), (3, 1), dtype=torch.float32, device=GPU_TYPE),
-                rand_strided((), (), dtype=d, device="cpu"),
-            )
-            self.assertTrue(same(opt(*inputs), fn(*inputs)))
-            inputs = (inputs[1], inputs[0])
-            self.assertTrue(same(opt(*inputs), fn(*inputs)))
+        inputs = (
+            rand_strided((2, 3), (3, 1), dtype=torch.float32, device=GPU_TYPE),
+            rand_strided((), (), dtype=dtype, device="cpu"),
+        )
+        self.assertTrue(same(opt(*inputs), fn(*inputs)))
+        inputs = (inputs[1], inputs[0])
+        self.assertTrue(same(opt(*inputs), fn(*inputs)))

     @dynamo_config.patch(automatic_dynamic_shapes=True)
     def test_list_clearing(self):
@@ -11048,53 +11061,41 @@ def fn(x):
         self.common(fn, (x,))

     @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
-    # We only support dtypeview for abi_conpatible aoti
-    @skip_if_triton_cpu("Compile time crash in Triton CPU CI")
+    # We only support dtypeview for abi_compatible aoti
     @torch._inductor.config.patch(abi_compatible=True)
-    def test_dtypeview(self):
+    @parametrize(
+        "dtype_x, dtype_y",
+        list(itertools.product(test_dtypes, test_dtypes)),
+    )
+    def test_dtypeview(self, dtype_x, dtype_y):
         if TEST_WITH_ASAN:
             return

+        if is_triton_cpu_backend(self.device):
+            raise unittest.SkipTest("Compile time crash in Triton CPU CI")
+
         # https://github.com/pytorch/pytorch/issues/126338
         def fn(x, y, x_dtype, x2):
             x = x.view(x_dtype)
             y = y.view(x_dtype) + 1
             x2 = x2.view(x_dtype) + 1
             return x @ y, x2 @ x

-        test_dtypes = [
-            torch.float32,
-            torch.float64,
-            torch.float16,
-            torch.bfloat16,
-            torch.uint8,
-            torch.int8,
-            torch.int16,
-            torch.int32,
-            torch.int64,
-        ]
-        for test_dtype_x in test_dtypes:
-            for test_dtype_y in test_dtypes:
-                # @ operation needs arguments to be the same dtype
-                for view_dtype in test_dtypes:
-                    try:
-                        # print(f"({test_dtype_x}, {test_dtype_y}, {view_dtype})")
-                        x = rand_strided(
-                            (2, 2), (2, 1), device=self.device, dtype=test_dtype_x
-                        )
-                        y = rand_strided(
-                            (2, 2), (2, 1), device=self.device, dtype=test_dtype_y
-                        )
-                        x2 = x.clone()
-                        fn(x, y, view_dtype, x2)
-                    except Exception as e:
-                        continue
-                    self.common(
-                        fn,
-                        (x, y, view_dtype, x2),
-                        reference_in_float=False,
-                        check_lowp=False,
-                    )
+        # @ operation needs arguments to be the same dtype
+        for view_dtype in test_dtypes:
+            try:
+                x = rand_strided((2, 2), (2, 1), device=self.device, dtype=dtype_x)
+                y = rand_strided((2, 2), (2, 1), device=self.device, dtype=dtype_y)
+                x2 = x.clone()
+                fn(x, y, view_dtype, x2)
+            except Exception as e:
+                continue
+            self.common(
+                fn,
+                (x, y, view_dtype, x2),
+                reference_in_float=False,
+                check_lowp=False,
+            )

     @torch._inductor.config.patch(abi_compatible=True)
     def test_dtypeview_fusion(self):
