diff --git a/.github/workflows/generate_binary_build_matrix.yml b/.github/workflows/generate_binary_build_matrix.yml index 562d69220b..c9ba6e9efe 100644 --- a/.github/workflows/generate_binary_build_matrix.yml +++ b/.github/workflows/generate_binary_build_matrix.yml @@ -43,6 +43,14 @@ on: description: "Generate binary build matrix for a python only package (i.e. only one python version)" default: "disable" type: string + use_split_build: + description: | + [Experimental] Build a libtorch only wheel and build pytorch such that + are built from the libtorch wheel. + required: false + type: boolean + default: false + outputs: matrix: description: "Generated build matrix" @@ -79,6 +87,7 @@ jobs: # In cases when pipy binaries are not published yet. USE_ONLY_DL_PYTORCH_ORG: ${{ inputs.use-only-dl-pytorch-org }} BUILD_PYTHON_ONLY: ${{ inputs.build-python-only }} + USE_SPLIT_BUILD: ${{ inputs.use_split_build }} run: | set -eou pipefail MATRIX_BLOB="$(python3 tools/scripts/generate_binary_build_matrix.py)" diff --git a/tools/scripts/generate_binary_build_matrix.py b/tools/scripts/generate_binary_build_matrix.py index 39b712d0a0..a6fc82255b 100644 --- a/tools/scripts/generate_binary_build_matrix.py +++ b/tools/scripts/generate_binary_build_matrix.py @@ -287,8 +287,13 @@ def get_wheel_install_command( desired_cuda: str, python_version: str, use_only_dl_pytorch_org: bool, + use_split_build: bool = False, ) -> str: - + if use_split_build: + if (gpu_arch_version in CUDA_ARCHES) and (os == LINUX) and (channel == NIGHTLY): + return f"{WHL_INSTALL_BASE} {PACKAGES_TO_INSTALL_WHL} --index-url {get_base_download_url_for_repo('whl', channel, gpu_arch_type, desired_cuda)}_pypi_pkg" + else: + raise ValueError("Split build is not supported for this configuration. It is only supported for CUDA 11.8, 12.1, 12.4 on Linux nightly builds.") if channel == RELEASE and (not use_only_dl_pytorch_org) and ( (gpu_arch_version == "12.1" and os == LINUX) or ( @@ -315,6 +320,7 @@ def generate_conda_matrix( with_cpu: str, limit_pr_builds: bool, use_only_dl_pytorch_org: bool, + use_split_build: bool = False, ) -> List[Dict[str, str]]: ret: List[Dict[str, str]] = [] python_versions = list(PYTHON_ARCHES) @@ -370,6 +376,7 @@ def generate_libtorch_matrix( with_cpu: str, limit_pr_builds: bool, use_only_dl_pytorch_org: bool, + use_split_build: bool = False, abi_versions: Optional[List[str]] = None, arches: Optional[List[str]] = None, libtorch_variants: Optional[List[str]] = None, @@ -458,6 +465,7 @@ def generate_wheels_matrix( with_cpu: str, limit_pr_builds: bool, use_only_dl_pytorch_org: bool, + use_split_build: bool = False, arches: Optional[List[str]] = None, python_versions: Optional[List[str]] = None, ) -> List[Dict[str, str]]: @@ -511,8 +519,7 @@ def generate_wheels_matrix( ) desired_cuda = translate_desired_cuda(gpu_arch_type, gpu_arch_version) - ret.append( - { + entry = { "python_version": python_version, "gpu_arch_type": gpu_arch_type, "gpu_arch_version": gpu_arch_version, @@ -535,8 +542,17 @@ def generate_wheels_matrix( "channel": channel, "upload_to_base_bucket": upload_to_base_bucket, "stable_version": CURRENT_VERSION, + "use_split_build": False, } - ) + ret.append(entry) + if use_split_build: + entry = entry.copy() + entry["build_name"] = f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-split".replace( + ".", "_" + ) + entry["use_split_build"] = True + ret.append(entry) + return ret @@ -557,6 +573,7 @@ def generate_build_matrix( limit_pr_builds: str, use_only_dl_pytorch_org: str, build_python_only: str, + use_split_build: str, ) -> Dict[str, List[Dict[str, str]]]: includes = [] @@ -578,6 +595,7 @@ def generate_build_matrix( with_cpu, limit_pr_builds == "true", use_only_dl_pytorch_org == "true", + use_split_build == "true", ) ) @@ -657,6 +675,14 @@ def main(args: List[str]) -> None: default=os.getenv("BUILD_PYTHON_ONLY", ENABLE), ) + parser.add_argument( + "--use-split-build", + help="Use split build for wheel", + type=str, + choices=["true", "false"], + default=os.getenv("USE_SPLIT_BUILD", DISABLE), + ) + options = parser.parse_args(args) assert ( @@ -673,6 +699,7 @@ def main(args: List[str]) -> None: options.limit_pr_builds, options.use_only_dl_pytorch_org, options.build_python_only, + options.use_split_build, ) print(json.dumps(build_matrix))