Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docker release matrix workflow #5065

Merged
merged 3 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions .github/workflows/generate_docker_release_matrix.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: Generates the docker release matrix

on:
workflow_call:
inputs:
channel:
description: "Channel to use (nightly, test, release, all)"
default: ""
type: string
test-infra-repository:
description: "Test infra repository to use"
default: "pytorch/test-infra"
type: string
test-infra-ref:
description: "Test infra reference to use"
default: "main"
type: string
outputs:
matrix:
description: "Generated build matrix"
value: ${{ jobs.generate.outputs.matrix }}

jobs:
generate:
outputs:
matrix: ${{ steps.generate.outputs.matrix }}
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Checkout test-infra repository
uses: actions/checkout@v3
with:
repository: ${{ inputs.test-infra-repository }}
ref: ${{ inputs.test-infra-ref }}
- uses: ./.github/actions/set-channel
- name: Generate docker release matrix
id: generate
env:
CHANNEL: ${{ inputs.channel != '' && inputs.channel || env.CHANNEL }}
run: |
set -eou pipefail
MATRIX_BLOB="$(python3 tools/scripts/generate_docker_release_matrix.py)"
echo "${MATRIX_BLOB}"
echo "matrix=${MATRIX_BLOB}" >> "${GITHUB_OUTPUT}"

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ inputs.package-type }}-${{ inputs.os }}-${{ inputs.test-infra-repository }}-${{ inputs.test-infra-ref }}
Fixed Show fixed Hide fixed
cancel-in-progress: true
5 changes: 5 additions & 0 deletions tools/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
"release": ["5.6", "5.7"],
}

CUDA_CUDDN_VERSIONS = {
"11.8": { "cuda": "11.8.0", "cudnn": "8" },
"12.1": { "cuda": "12.1.1", "cudnn": "8"},
}

PACKAGE_TYPES = ["wheel", "conda", "libtorch"]
PRE_CXX11_ABI = "pre-cxx11"
CXX11_ABI = "cxx11-abi"
Expand Down
58 changes: 58 additions & 0 deletions tools/scripts/generate_docker_release_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3

"""Generates a matrix for docker releases through github actions

Will output a condensed version of the matrix. Will include fllowing:
* CUDA version short
* CUDA full version
* CUDNN version short
* Image type either runtime or devel
* Platform linux/arm64,linux/amd64

"""

import json
import os
import sys
import argparse
from typing import Dict, List

import generate_binary_build_matrix

DOCKER_IMAGE_TYPES = ["runtime", "devel"]


def generate_docker_matrix(channel: str) -> Dict[str, List[Dict[str, str]]]:

ret: List[Dict[str, str]] = []
for cuda in generate_binary_build_matrix.CUDA_ARCHES_DICT[channel]:
version = generate_binary_build_matrix.CUDA_CUDDN_VERSIONS[cuda]
for image in DOCKER_IMAGE_TYPES:
ret.append(
{
"cuda": cuda,
"cuda_full_version": version["cuda"],
"cudnn_version": version["cudnn"],
"image_type": image,
"platform": "linux/arm64,linux/amd64",
}
)
return {"include": ret}


def main(args) -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--channel",
help="Channel to use, default nightly",
type=str,
choices=["nightly", "test", "release", "all"],
default=os.getenv("CHANNEL", "nightly"),
)
options = parser.parse_args(args)

build_matrix = generate_docker_matrix(options.channel)
print(json.dumps(build_matrix))

if __name__ == "__main__":
main(sys.argv[1:])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks weird, I don't think you need to pass sys.argv to argparse like this, just calling options = parser.parse_args() without any arguments works.

Loading