Extend vectorization with SVE(ARM) with Torch Compile (Inductor) (pytorch#134672)

**Motivation**
Enable SVE vectorization with `torch.compile`
Extends PR: pytorch#119571

* This PR enables vectorization in the Inductor codegen path using SVE-256 (256-bit vector length)
* The changes can be extended to other SVE vector lengths

I've compared the existing NEON implementation against the SVE-vectorized route for `torch.compile`.
Test results are for 8 cores on an ARM Neoverse_V1.

![Benchmark comparison of NEON vs. SVE-256 vectorization with torch.compile (8 cores, ARM Neoverse_V1)](https://github.com/user-attachments/assets/6961fbea-8285-4ca3-b92e-934a2db50ee2)

It's worth mentioning that for a standalone `SiLU` op there's a `~1.8x` speedup with `torch.compile`.
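A minimal sketch of how to exercise this locally (shape and model are illustrative, not the exact benchmark setup):

```python
import torch

# On an SVE-capable aarch64 host this PR makes Inductor pick SVE-256;
# otherwise codegen falls back to the existing NEON path.
print("SVE supported:", torch.cpu._is_arm_sve_supported())

silu = torch.nn.SiLU()
compiled_silu = torch.compile(silu)

x = torch.randn(1024, 1024)  # illustrative shape
with torch.no_grad():
    torch.testing.assert_close(compiled_silu(x), silu(x))
```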

Pull Request resolved: pytorch#134672
Approved by: https://github.com/jgong5, https://github.com/malfet
aditew01 authored and pytorchmergebot committed Oct 10, 2024
1 parent 479bd1f commit 575f260
Showing 9 changed files with 69 additions and 4 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/cpu/Utils.cpp
@@ -84,6 +84,14 @@ bool init_amx() {
#endif
}

bool is_arm_sve_supported() {
#if !defined(__s390x__) && !defined(__powerpc__)
return cpuinfo_initialize() && cpuinfo_has_arm_sve();
#else
return false;
#endif
}

static uint32_t get_cache_size(int level) {
#if !defined(__s390x__) && !defined(__powerpc__)
if (!cpuinfo_initialize()) {
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/Utils.h
@@ -21,6 +21,9 @@ TORCH_API bool is_amx_tile_supported();
// Enable the system to use AMX instructions.
TORCH_API bool init_amx();

// Detect if CPU supports Arm(R) architecture SVE ISA
TORCH_API bool is_arm_sve_supported();

// Get the L1 cache size per core in Byte
TORCH_API uint32_t L1d_cache_size();

24 changes: 24 additions & 0 deletions aten/src/ATen/cpu/vec/functional_base.h
@@ -107,6 +107,30 @@ struct VecReduceAllSIMD<float, Op> {
};
#endif // defined(__aarch64__)

#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)
template <typename Op>
struct VecReduceAllSIMD<float, Op> {
static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
using Vec = Vectorized<float>;
Vec v = acc_vec;
// 128-bit shuffle
svuint32_t ind = svdupq_n_u32(4, 5, 6, 7);
Vec v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 64-bit shuffle
ind = svdupq_n_u32(2, 3, 0, 1);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
// 32-bit shuffle
ind = svdupq_n_u32(1, 0, 2, 3);
v1 = svtbl_f32(v, ind);
v = vec_fun(v, v1);
return svlasta(svpfalse(), v);
}
};
#endif // defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256)


template <typename scalar_t, typename Op>
inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized<scalar_t>& acc_vec) {
return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
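For readers less familiar with the intrinsics above: `svdupq_n_u32` repeats a 4-lane index pattern across the 8-lane SVE-256 register, `svtbl_f32` gathers lanes by index, and `svlasta(svpfalse(), v)` extracts lane 0. A minimal NumPy sketch of the same shuffle-and-combine reduction:

```python
import numpy as np

# Sketch of the SVE-256 VecReduceAllSIMD path for 8 float32 lanes.
def vec_reduce_all(vec_fun, v):
    # 128-bit shuffle: index pattern {4,5,6,7} repeated -> [4,5,6,7,4,5,6,7]
    v = vec_fun(v, v[[4, 5, 6, 7, 4, 5, 6, 7]])
    # 64-bit shuffle: pattern {2,3,0,1} repeated across the register
    v = vec_fun(v, v[[2, 3, 0, 1, 2, 3, 0, 1]])
    # 32-bit shuffle: pattern {1,0,2,3} repeated across the register
    v = vec_fun(v, v[[1, 0, 2, 3, 1, 0, 2, 3]])
    return v[0]  # lane 0 now holds the reduction over all 8 lanes

x = np.arange(8, dtype=np.float32)
assert vec_reduce_all(np.maximum, x) == 7.0
assert vec_reduce_all(np.add, x) == 28.0
```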
1 change: 1 addition & 0 deletions torch/_C/_cpu.pyi
@@ -8,5 +8,6 @@ def _is_avx512_vnni_supported() -> _bool: ...
def _is_avx512_bf16_supported() -> _bool: ...
def _is_amx_tile_supported() -> _bool: ...
def _init_amx() -> _bool: ...
def _is_arm_sve_supported() -> _bool: ...
def _L1d_cache_size() -> _int: ...
def _L2_cache_size() -> _int: ...
2 changes: 2 additions & 0 deletions torch/_dynamo/trace_rules.py
@@ -421,6 +421,7 @@
"torch._C._cpu._is_avx512_bf16_supported",
"torch._C._cpu._is_amx_tile_supported",
"torch._C._cpu._init_amx",
"torch._C._cpu._is_arm_sve_supported",
"torch._C._crash_if_aten_asan",
"torch._C._crash_if_csrc_asan",
"torch._C._crash_if_csrc_ubsan",
@@ -2445,6 +2446,7 @@
"torch._C._cpu._is_avx512_bf16_supported",
"torch._C._cpu._is_amx_tile_supported",
"torch.cpu._init_amx",
"torch._C._cpu._is_arm_sve_supported",
"torch.cpu.current_device",
"torch.cpu.current_stream",
"torch.cpu.device_count",
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_prefix.h
@@ -28,7 +28,7 @@
#include <c10/util/TypeCast.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256)
#define INDUCTOR_USE_VECTOR_TYPES() 1
#else
#define INDUCTOR_USE_VECTOR_TYPES() 0
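With `CPU_CAPABILITY_SVE256` added to this gate, Inductor-generated C++ on SVE builds compiles with `INDUCTOR_USE_VECTOR_TYPES() == 1` and can use `at::vec::Vectorized<T>`. A hedged sketch of the existing user-facing knob (assuming `cpp.simdlen` behaves as documented, where `1` disables vectorized codegen):

```python
import torch._inductor.config as inductor_config

# Assumption: simdlen=1 turns off vectorized codegen, so the generated
# C++ takes the INDUCTOR_USE_VECTOR_TYPES() == 0 branch.
inductor_config.cpp.simdlen = 1

# Default (None) lets Inductor pick the widest valid ISA; with this PR
# that can be SVE-256 on capable aarch64 hardware.
inductor_config.cpp.simdlen = None
```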
27 changes: 24 additions & 3 deletions torch/_inductor/cpu_vec_isa.py
@@ -52,7 +52,7 @@ class VecISA:
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
# making the runtime check unnecessary.
_avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif
@@ -161,6 +161,24 @@ def __str__(self) -> str:
__hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecSVE(VecISA):
    # This class can be repurposed for SVE with other vector lengths
_bit_width = 256
_macro = [
"CPU_CAPABILITY_SVE",
"CPU_CAPABILITY_SVE256",
"AT_BUILD_ARM_VEC256_WITH_SLEEF",
]
_arch_flags = "-march=armv8-a+sve -msve-vector-bits=256"
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}

def __str__(self) -> str:
return "asimd"

__hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX512(VecISA):
_bit_width = 512
@@ -306,7 +324,7 @@ def _check_and_append_supported_isa(


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]


# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
@@ -338,7 +356,10 @@ def valid_vec_isa_list() -> List[VecISA]:
elif arch == "ppc64le":
isa_list.append(VecVSX())
elif arch == "aarch64":
isa_list.append(VecNEON())
if torch.cpu._is_arm_sve_supported():
isa_list.append(VecSVE())
else:
isa_list.append(VecNEON())
elif arch in ["x86_64", "AMD64"]:
"""
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
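Note that `VecSVE.__str__` returns `"asimd"`, the same capability string `VecNEON` reports. A short sketch of how the new selection surfaces (assuming this module's existing `valid_vec_isa_list()` and `pick_vec_isa()` entry points and the `bit_width()` helper):

```python
import torch
from torch._inductor import cpu_vec_isa

# On aarch64 with SVE, valid_vec_isa_list() now yields VecSVE in place
# of VecNEON (per the branch added above).
for isa in cpu_vec_isa.valid_vec_isa_list():
    print(type(isa).__name__, isa.bit_width())

# pick_vec_isa() is what codegen consults; it can now return the
# 256-bit SVE ISA on e.g. Neoverse V1 instead of 128-bit NEON.
print(cpu_vec_isa.pick_vec_isa())
```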
5 changes: 5 additions & 0 deletions torch/cpu/__init__.py
@@ -60,6 +60,11 @@ def _init_amx() -> bool:
return torch._C._cpu._init_amx()


def _is_arm_sve_supported() -> bool:
r"""Returns a bool indicating if CPU supports Arm SVE."""
return torch._C._cpu._is_arm_sve_supported()


def is_available() -> bool:
r"""Returns a bool indicating if CPU is currently available.
1 change: 1 addition & 0 deletions torch/csrc/cpu/Module.cpp
@@ -14,6 +14,7 @@ void initModule(PyObject* module) {
cpu.def("_is_avx512_bf16_supported", at::cpu::is_avx512_bf16_supported);
cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported);
cpu.def("_init_amx", at::cpu::init_amx);
cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported);
cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size);
cpu.def("_L2_cache_size", at::cpu::L2_cache_size);
}
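With this binding, the runtime check is reachable end-to-end: `at::cpu::is_arm_sve_supported()` in ATen, `torch._C._cpu._is_arm_sve_supported` via the Python binding, and the `torch.cpu` wrapper:

```python
import torch

# Both spellings resolve to the same cpuinfo-backed check added in
# aten/src/ATen/cpu/Utils.cpp above.
print(torch._C._cpu._is_arm_sve_supported())
print(torch.cpu._is_arm_sve_supported())
```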
