Skip to content

Commit

Permalink
[aarch64] add cuda aarch64 torchvision and torchaudio (#5387)
Browse files Browse the repository at this point in the history
add cuda aarch64 build for torchvision and torchaudio
  • Loading branch information
tinglvv authored Jul 3, 2024
1 parent 4432e2c commit b623e13
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions tools/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
# Accelerator architectures
CPU = "cpu"
CPU_AARCH64 = "cpu-aarch64"
CUDA_AARCH64 = "cuda-aarch64"
CUDA = "cuda"
ROCM = "rocm"

Expand All @@ -80,6 +81,7 @@
LINUX_GPU_RUNNER = "linux.g5.4xlarge.nvidia.gpu"
LINUX_CPU_RUNNER = "linux.2xlarge"
LINUX_AARCH64_RUNNER = "linux.arm64.2xlarge"
LINUX_AARCH64_GPU_RUNNER = "linux.arm64.m7g.4xlarge"
WIN_GPU_RUNNER = "windows.8xlarge.nvidia.gpu"
WIN_CPU_RUNNER = "windows.4xlarge"
MACOS_M1_RUNNER = "macos-m1-stable"
Expand All @@ -103,6 +105,8 @@ def arch_type(arch_version: str) -> str:
return ROCM
elif arch_version == CPU_AARCH64:
return CPU_AARCH64
elif arch_version == CUDA_AARCH64:
return CUDA_AARCH64
else: # arch_version should always be CPU in this case
return CPU

Expand All @@ -114,7 +118,10 @@ def validation_runner(arch_type: str, os: str) -> str:
else:
return LINUX_CPU_RUNNER
elif os == LINUX_AARCH64:
return LINUX_AARCH64_RUNNER
if arch_type == CUDA_AARCH64:
return LINUX_AARCH64_GPU_RUNNER
else:
return LINUX_AARCH64_RUNNER
elif os == WINDOWS:
if arch_type == CUDA:
return WIN_GPU_RUNNER
Expand Down Expand Up @@ -154,6 +161,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None:
},
CPU: "pytorch/manylinux-builder:cpu",
CPU_AARCH64: "pytorch/manylinuxaarch64-builder:cpu-aarch64",
CUDA_AARCH64: "pytorch/manylinuxaarch64-builder:cuda12.4",
}
CONDA_CONTAINER_IMAGES = {
**{
Expand Down Expand Up @@ -188,6 +196,7 @@ def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
return {
CPU: "cpu",
CPU_AARCH64: CPU,
CUDA_AARCH64: "cu124",
CUDA: f"cu{gpu_arch_version.replace('.', '')}",
ROCM: f"rocm{gpu_arch_version}",
}.get(gpu_arch_type, gpu_arch_version)
Expand Down Expand Up @@ -490,7 +499,7 @@ def generate_wheels_matrix(
if os == LINUX_AARCH64:
# Only want the one arch as the CPU type is different and
# uses different build/test scripts
arches = [CPU_AARCH64]
arches = [CPU_AARCH64, CUDA_AARCH64]

if with_cuda == ENABLE:
upload_to_base_bucket = "no"
Expand Down

0 comments on commit b623e13

Please sign in to comment.