[Inductor] Pick ISA for inductor based on ATEN_CPU_CAPABILITY (pytorch#123514)

Part of pytorch#123224. Pick the ISA based on the environment variable ATEN_CPU_CAPABILITY so that the CPU vectorization ISA level used by Inductor can be controlled the same way as in eager mode.

Pull Request resolved: pytorch#123514
Approved by: https://github.com/jgong5, https://github.com/peterbell10
CaoE authored and pytorchmergebot committed Sep 30, 2024
1 parent 9dbc6ba commit 6931c16
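
For orientation before the diffs: the selection logic exercised by the new tests can be sketched as below. This is a hedged sketch, not the actual torch/_inductor/cpu_vec_isa.py implementation; the function name pick_vec_isa_sketch, the supported_isas parameter, and the bit_widths table are illustrative assumptions. An explicit cpp.simdlen takes precedence; otherwise ATEN_CPU_CAPABILITY is consulted, with "default" disabling vectorization and unset, unknown, or inapplicable values falling back to the highest supported ISA.

import os
from typing import List, Optional

def pick_vec_isa_sketch(supported_isas: List[str], simdlen: Optional[int]) -> Optional[str]:
    # supported_isas is ordered best-first, e.g. ["amx", "avx512", "avx2"] on x86_64,
    # or ["neon"] / ["zvector"] / ["vsx"] elsewhere; returning None means "scalar kernels".
    bit_widths = {
        "amx": 512, "avx512": 512, "avx2": 256, "neon": 256, "zvector": 256, "vsx": 256
    }
    if not supported_isas:
        return None
    if simdlen is not None:
        # An explicit cpp.simdlen must exactly match a supported bit width;
        # 0, 1, or any unsupported width disables vectorization.
        for isa in supported_isas:
            if bit_widths[isa] == simdlen:
                return isa
        return None
    capability = os.getenv("ATEN_CPU_CAPABILITY")
    if capability == "default":
        return None  # scalar kernels, mirroring eager's ATEN_CPU_CAPABILITY=default
    if capability == "avx2" and "avx2" in supported_isas:
        return "avx2"  # cap at AVX2 even when AVX-512/AMX is available
    # Unset, or naming an ISA that does not apply to this machine: pick the best one.
    return supported_isas[0]

Under these assumptions the environment variable can only lower the ISA level or disable vectorization; it never selects an ISA the machine does not support, which is exactly what the aarch64/s390x/ppc64le test below asserts for the x86-only values.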
Showing 4 changed files with 232 additions and 25 deletions.
174 changes: 161 additions & 13 deletions test/inductor/test_cpu_repro.py
@@ -4,6 +4,7 @@
 import functools
 import itertools
 import math
+import os
 import platform
 import sys
 import unittest
@@ -60,12 +61,16 @@
 check_model = test_torchinductor.check_model
 
 requires_vectorization = unittest.skipUnless(
-    cpu_vec_isa.valid_vec_isa_list(), "Does not support vectorization"
+    cpu_vec_isa.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default",
+    "Does not support vectorization",
 )
 
 
 def check_metrics_vec_kernel_count(num_expected_vec_kernels):
-    if cpu_vec_isa.valid_vec_isa_list():
+    if (
+        cpu_vec_isa.valid_vec_isa_list()
+        and os.getenv("ATEN_CPU_CAPABILITY") != "default"
+    ):
         assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
@@ -1586,6 +1591,78 @@ def fn(x):
         metrics.reset()
         self.common(fn, (value,))
 
+    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
+    @unittest.skipIf(
+        not cpu_vec_isa.valid_vec_isa_list()
+        or "avx2" in [str(vec_isa) for vec_isa in cpu_vec_isa.valid_vec_isa_list()],
+        "Does not support vectorization or not s390x/aarch64/ppc64le machine",
+    )
+    @patch("torch.cuda.is_available", lambda: False)
+    def test_auto_zvec_neon_vsx_simd(self):
+        vec_zvec_neon_vsx = cpu_vec_isa.valid_vec_isa_list()[0]
+        self.assertTrue(vec_zvec_neon_vsx.bit_width() == 256)
+
+        with config.patch({"cpp.simdlen": 0}):
+            isa = cpu_vec_isa.pick_vec_isa()
+            self.assertFalse(isa)
+
+        with config.patch({"cpp.simdlen": 1}):
+            isa = cpu_vec_isa.pick_vec_isa()
+            self.assertFalse(isa)
+
+        with config.patch({"cpp.simdlen": 257}):
+            isa = cpu_vec_isa.pick_vec_isa()
+            self.assertFalse(isa)
+
+        with config.patch({"cpp.simdlen": 256}):
+            isa = cpu_vec_isa.pick_vec_isa()
+            self.assertTrue(isa == vec_zvec_neon_vsx)
+
+        pre_var = os.getenv("ATEN_CPU_CAPABILITY")
+        if pre_var:
+            os.environ.pop("ATEN_CPU_CAPABILITY")
+
+        try:
+            with config.patch({"cpp.simdlen": None}):
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "default"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertFalse(isa)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "neon"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "vsx"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertTrue(isa == vec_zvec_neon_vsx)
+
+        finally:
+            if pre_var:
+                os.environ["ATEN_CPU_CAPABILITY"] = pre_var
+            elif os.getenv("ATEN_CPU_CAPABILITY"):
+                os.environ.pop("ATEN_CPU_CAPABILITY")
+
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(
         platform.machine() != "x86_64" or not cpu_vec_isa.valid_vec_isa_list(),
@@ -1606,15 +1683,6 @@ def test_auto_simd(self):
         self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32)
         self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
 
-        with config.patch({"cpp.simdlen": None}):
-            isa = cpu_vec_isa.pick_vec_isa()
-            if vec_amx in cpu_vec_isa.valid_vec_isa_list():
-                self.assertTrue(isa == vec_amx)
-            elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
-                self.assertTrue(isa == vec_avx512)
-            else:
-                self.assertTrue(isa == vec_avx2)
-
         with config.patch({"cpp.simdlen": 0}):
             isa = cpu_vec_isa.pick_vec_isa()
             self.assertFalse(isa)
@@ -1646,6 +1714,81 @@ def test_auto_simd(self):
             isa = cpu_vec_isa.pick_vec_isa()
             self.assertTrue(isa == vec_avx2)
 
+        pre_var = os.getenv("ATEN_CPU_CAPABILITY")
+        if pre_var:
+            os.environ.pop("ATEN_CPU_CAPABILITY")
+
+        try:
+            with config.patch({"cpp.simdlen": None}):
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_amx)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx512)
+                else:
+                    self.assertTrue(isa == vec_avx2)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "avx2"
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx2)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx2)
+                elif vec_avx2 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx2)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "avx512"
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_amx)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx512)
+                else:
+                    self.assertTrue(isa == vec_avx2)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "default"
+                isa = cpu_vec_isa.pick_vec_isa()
+                self.assertFalse(isa)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "neon"
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_amx)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx512)
+                else:
+                    self.assertTrue(isa == vec_avx2)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "zvector"
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_amx)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx512)
+                else:
+                    self.assertTrue(isa == vec_avx2)
+
+            with config.patch({"cpp.simdlen": None}):
+                os.environ["ATEN_CPU_CAPABILITY"] = "vsx"
+                isa = cpu_vec_isa.pick_vec_isa()
+                if vec_amx in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_amx)
+                elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
+                    self.assertTrue(isa == vec_avx512)
+                else:
+                    self.assertTrue(isa == vec_avx2)
+
+        finally:
+            if pre_var:
+                os.environ["ATEN_CPU_CAPABILITY"] = pre_var
+            elif os.getenv("ATEN_CPU_CAPABILITY"):
+                os.environ.pop("ATEN_CPU_CAPABILITY")
+
     @requires_vectorization
     @patch("torch.cuda.is_available", lambda: False)
     def test_masked_fill_softmax(self):
@@ -2626,6 +2769,7 @@ def fn(x, y):
             1,
         )
 
+    @requires_vectorization
     def test_argmin(self):
         def fn(x):
             return torch.argmin(x, -1)
@@ -2637,6 +2781,7 @@ def fn(x):
         self.common(fn, (x,))
         assert metrics.generated_cpp_vec_kernel_count == 1
 
+    @requires_vectorization
     def test_argmax_argmin_with_nan_value(self):
         def fn(x):
             return torch.argmax(x)
@@ -3521,6 +3666,7 @@ def forward(self, idx, x):
         self.common(m, (idx, x))
         check_metrics_vec_kernel_count(1)
 
+    @requires_vectorization
     def test_embedding_vec_bf16(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:
@@ -3862,7 +4008,7 @@ def fn(x):
         x = torch.randint(0, 100, (819,), dtype=torch.int64)
         metrics.reset()
         self.common(fn, (x,))
-        assert metrics.generated_cpp_vec_kernel_count == 1
+        check_metrics_vec_kernel_count(1)
 
     def test_highp_to_lowp_cse_var_cache_with_store(self):
         # Fix issue: https://github.com/pytorch/pytorch/issues/128263
@@ -3896,7 +4042,7 @@ def fn(x):
         x = torch.randint(0, 100, (22, 51), dtype=torch.int64)
         metrics.reset()
         self.common(fn, (x,))
-        assert metrics.generated_cpp_vec_kernel_count == 1
+        check_metrics_vec_kernel_count(1)
 
     @config.patch({"cpp.dynamic_threads": True})
     def test_reduction_with_dynamic_threads(self):
@@ -4007,6 +4153,7 @@ def fn(arg0_1, arg0_2):
             exactly=True,
         ).run(code)
 
+    @requires_vectorization
     def test_repeated_exp(self):
         def fn(x):
             y = x.sigmoid()
@@ -4035,6 +4182,7 @@ def fn(x):
         self.common(fn, (x,))
         check_metrics_vec_kernel_count(1)
 
+    @requires_vectorization
     def test_consistent_remove_buffers(self):
         def fn(x):
             z = x + x
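
Both tests above override ATEN_CPU_CAPABILITY inline and restore the previous value in a finally block. The same save/restore bookkeeping could be factored into a context manager; a minimal sketch, not part of this commit:

import contextlib
import os

@contextlib.contextmanager
def aten_cpu_capability(value):
    # Temporarily set (or, for value=None, unset) ATEN_CPU_CAPABILITY and
    # restore the prior environment on exit, even if the body raises.
    prior = os.environ.pop("ATEN_CPU_CAPABILITY", None)
    try:
        if value is not None:
            os.environ["ATEN_CPU_CAPABILITY"] = value
        yield
    finally:
        if prior is not None:
            os.environ["ATEN_CPU_CAPABILITY"] = prior
        else:
            os.environ.pop("ATEN_CPU_CAPABILITY", None)

With such a helper, each pre_var block above would collapse to `with aten_cpu_capability("avx2"): ...`.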
5 changes: 4 additions & 1 deletion test/inductor/test_extension_backend.py
@@ -148,7 +148,10 @@ def fn(a, b, c):
         metrics.reset()
         opt_fn = torch.compile()(fn)
         _, code = run_and_get_cpp_code(opt_fn, x, y, z)
-        if cpu_vec_isa.valid_vec_isa_list():
+        if (
+            cpu_vec_isa.valid_vec_isa_list()
+            and os.getenv("ATEN_CPU_CAPABILITY") != "default"
+        ):
             load_expr = "loadu"
         else:
             load_expr = " = in_ptr0[static_cast<long>(i0)];"
38 changes: 31 additions & 7 deletions test/inductor/test_torchinductor.py
@@ -99,7 +99,7 @@
 importlib.import_module("functorch")
 importlib.import_module("filelock")
 
-from torch._inductor import config, test_operators
+from torch._inductor import config, cpu_vec_isa, test_operators
 from torch._inductor.compile_fx import (
     compile_fx,
     compile_fx_inner,
@@ -1523,10 +1523,16 @@ def test(
                 pass  # no device asserts in halide
             elif self.device == "cpu":
                 _, code = run_and_get_cpp_code(fn_opt, *inps)
-                self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping)
                 self.assertTrue(("TORCH_CHECK" in code) is has_assert)
-                # Assert that we always vectorize the kernel regardless of wrapping / checks
-                self.assertTrue(("loadu" in code) is vectorize)
+                if (
+                    cpu_vec_isa.valid_vec_isa_list()
+                    and os.getenv("ATEN_CPU_CAPABILITY") != "default"
+                ):
+                    self.assertTrue(
+                        (") ? (" in code or "blendv" in code) is has_wrapping
+                    )
+                    # Assert that we always vectorize the kernel regardless of wrapping / checks
+                    self.assertTrue(("loadu" in code) is vectorize)
             else:
                 code = run_and_get_triton_code(fn_opt, *inps)
                 self.assertTrue(("tl.where" in code) is has_wrapping)
@@ -1838,8 +1844,20 @@ def test_multilayer_var_lowp(self):
         def fn(a):
             return torch.var(a)
 
-        self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),))
-        self.common(fn, (torch.rand((14923), dtype=torch.float16),))
+        atol = None
+        rtol = None
+        if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default":
+            atol = 1e-3
+            rtol = 1e-3
+        self.common(
+            fn,
+            (torch.rand((16, 16, 352, 352), dtype=torch.float16),),
+            atol=atol,
+            rtol=rtol,
+        )
+        self.common(
+            fn, (torch.rand((14923), dtype=torch.float16),), atol=atol, rtol=rtol
+        )
 
     def test_split_cumsum(self):
         def fn(a):
@@ -10103,9 +10121,15 @@ def fn(query, scores, window_overlap):
         if is_cpp_backend(self.device):
             opt_fn = torch._dynamo.optimize("inductor")(fn)
             _, code = run_and_get_cpp_code(opt_fn, *args)
+            num = (
+                2
+                if cpu_vec_isa.valid_vec_isa_list()
+                and os.getenv("ATEN_CPU_CAPABILITY") != "default"
+                else 1
+            )
             FileCheck().check_count(
                 "static_cast<int64_t>(256)",
-                2,
+                num,
                 exactly=True,
             ).run(code)
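
As a quick end-to-end check of the new behavior, one can compile a small function under ATEN_CPU_CAPABILITY=default and confirm the generated C++ contains no vectorized loads, i.e. the same "loadu" marker the tests above grep for. A hypothetical snippet, assuming a CPU build with Inductor and that run_and_get_cpp_code is importable from torch._inductor.utils as in the tests:

import os
os.environ["ATEN_CPU_CAPABILITY"] = "default"  # set before compiling anything

import torch
from torch._inductor.utils import run_and_get_cpp_code

def f(x):
    return (x + 1).relu()

# With capability "default", Inductor should emit scalar CPU kernels only.
_, code = run_and_get_cpp_code(torch.compile(f), torch.randn(64))
assert "loadu" not in code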