Skip to content

Commit

Permalink
Option to Generate Matrix without ROCM
Browse files Browse the repository at this point in the history
  • Loading branch information
osalpekar committed Jul 7, 2023
1 parent 5ffa10a commit a492bfe
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions tools/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def get_wheel_install_command(os: str, channel: str, gpu_arch_type: str, gpu_arc
whl_install_command = f"{WHL_INSTALL_BASE} --pre {PACKAGES_TO_INSTALL_WHL}" if channel == "nightly" else f"{WHL_INSTALL_BASE} {PACKAGES_TO_INSTALL_WHL}"
return f"{whl_install_command} --index-url {get_base_download_url_for_repo('whl', channel, gpu_arch_type, desired_cuda)}"

def generate_conda_matrix(os: str, channel: str, with_cuda: str, limit_pr_builds: bool) -> List[Dict[str, str]]:
def generate_conda_matrix(os: str, channel: str, with_cuda: str, with_rocm: str, limit_pr_builds: bool) -> List[Dict[str, str]]:
ret: List[Dict[str, str]] = []
arches = ["cpu"]
python_versions = list(mod.PYTHON_ARCHES)
Expand Down Expand Up @@ -247,6 +247,7 @@ def generate_libtorch_matrix(
os: str,
channel: str,
with_cuda: str,
with_rocm: str,
limit_pr_builds: str,
abi_versions: Optional[List[str]] = None,
arches: Optional[List[str]] = None,
Expand All @@ -265,9 +266,13 @@ def generate_libtorch_matrix(
if with_cuda == ENABLE:
if os == "linux":
arches += mod.CUDA_ARCHES
arches += mod.ROCM_ARCHES
elif os == "windows":
arches += mod.CUDA_ARCHES

if with_rocm == ENABLE:
if os == "linux":
arches += mod.ROCM_ARCHES


if abi_versions is None:
if os == "windows":
Expand Down Expand Up @@ -336,6 +341,7 @@ def generate_wheels_matrix(
os: str,
channel: str,
with_cuda: str,
with_rocm: str,
limit_pr_builds: bool,
arches: Optional[List[str]] = None,
python_versions: Optional[List[str]] = None,
Expand All @@ -358,10 +364,14 @@ def generate_wheels_matrix(
if with_cuda == ENABLE:
upload_to_base_bucket = "no"
if os == "linux":
arches += mod.CUDA_ARCHES + mod.ROCM_ARCHES
arches += mod.CUDA_ARCHES
elif os == "windows":
arches += mod.CUDA_ARCHES

if with_rocm == ENABLE:
if os == "linux":
arches += mod.ROCM_ARCHES

if limit_pr_builds:
python_versions = [ python_versions[0] ]

Expand Down Expand Up @@ -427,6 +437,13 @@ def main(args) -> None:
choices=[ENABLE, DISABLE],
default=os.getenv("WITH_CUDA", ENABLE),
)
parser.add_argument(
"--with-rocm",
help="Build with Rocm?",
type=str,
choices=[ENABLE, DISABLE],
default=os.getenv("WITH_ROCM", ENABLE),
)
# By default this is false for this script but expectation is that the caller
# workflow will default this to be true most of the time, where a pull
# request is synchronized and does not contain the label "ciflow/binaries/all"
Expand All @@ -452,20 +469,11 @@ def main(args) -> None:
for channel in channels:
for package in package_types:
initialize_globals(channel)
if package == "wheel":
includes.extend(
GENERATING_FUNCTIONS_BY_PACKAGE_TYPE[package](options.operating_system,
channel,
options.with_cuda,
options.limit_pr_builds == "true")
)
else:
includes.extend(
GENERATING_FUNCTIONS_BY_PACKAGE_TYPE[package](options.operating_system,
channel,
options.with_cuda,
options.limit_pr_builds == "true")
)
GENERATING_FUNCTIONS_BY_PACKAGE_TYPE[package](options.operating_system,
channel,
options.with_cuda,
options.with_rocm,
options.limit_pr_builds == "true")


print(json.dumps({"include": includes}))
Expand Down

0 comments on commit a492bfe

Please sign in to comment.