Skip to content

Commit

Permalink
update script with dynamic torch-ecosystem versions determinations (#318
Browse files Browse the repository at this point in the history
)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Oct 18, 2024
1 parent 24081b0 commit 1d0e380
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 37 deletions.
56 changes: 56 additions & 0 deletions .github/workflows/ci-scripts.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: Test scripts

on:
push:
branches: [main, "release/*"]
pull_request:
branches: [main, "release/*"]

defaults:
run:
shell: bash

jobs:
test-scripts:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-22.04", "macos-12", "windows-2022"]
python-version: ["3.8", "3.12"]
timeout-minutes: 35
steps:
- name: Checkout 🛎
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Python 🐍 ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"

- name: Install dependencies
timeout-minutes: 5
run: |
pip install -r requirements/_tests.txt
pip --version
pip list
- name: test Scripts
working-directory: ./scripts
run: pytest . -v

scripts-guardian:
runs-on: ubuntu-latest
needs: test-scripts
if: always()
steps:
- run: echo "${{ needs.test-scripts.result }}"
- name: failing...
if: needs.test-scripts.result == 'failure'
run: exit 1
- name: cancelled or skipped...
if: contains(fromJSON('["cancelled", "skipped"]'), needs.test-scripts.result)
timeout-minutes: 1
run: sleep 90
159 changes: 122 additions & 37 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,141 @@
import sys
from typing import Dict, List, Optional

VERSIONS = [
{"torch": "2.6.0", "torchvision": "0.21.0", "torchtext": "0.18.0", "torchaudio": "2.6.0"}, # nightly
{"torch": "2.5.0", "torchvision": "0.20.0", "torchtext": "0.18.0", "torchaudio": "2.5.0"}, # stable
{"torch": "2.4.1", "torchvision": "0.19.1", "torchtext": "0.18.0", "torchaudio": "2.4.1"},
{"torch": "2.4.0", "torchvision": "0.19.0", "torchtext": "0.18.0", "torchaudio": "2.4.0"},
{"torch": "2.3.1", "torchvision": "0.18.1", "torchtext": "0.18.0", "torchaudio": "2.3.1"},
{"torch": "2.3.0", "torchvision": "0.18.0", "torchtext": "0.18.0", "torchaudio": "2.3.0"},
{"torch": "2.2.2", "torchvision": "0.17.2", "torchtext": "0.17.2", "torchaudio": "2.2.2"},
{"torch": "2.2.1", "torchvision": "0.17.1", "torchtext": "0.17.1", "torchaudio": "2.2.1"},
{"torch": "2.2.0", "torchvision": "0.17.0", "torchtext": "0.17.0", "torchaudio": "2.2.0"},
{"torch": "2.1.2", "torchvision": "0.16.2", "torchtext": "0.16.2", "torchaudio": "2.1.2"},
{"torch": "2.1.1", "torchvision": "0.16.1", "torchtext": "0.16.1", "torchaudio": "2.1.1"},
{"torch": "2.1.0", "torchvision": "0.16.0", "torchtext": "0.16.0", "torchaudio": "2.1.0"},
{"torch": "2.0.1", "torchvision": "0.15.2", "torchtext": "0.15.2", "torchaudio": "2.0.2"},
{"torch": "2.0.0", "torchvision": "0.15.1", "torchtext": "0.15.1", "torchaudio": "2.0.1"},
{"torch": "1.14.0", "torchvision": "0.15.0", "torchtext": "0.15.0", "torchaudio": "0.14.0"}, # nightly / shifted
{"torch": "1.13.1", "torchvision": "0.14.1", "torchtext": "0.14.1", "torchaudio": "0.13.1"},
{"torch": "1.13.0", "torchvision": "0.14.0", "torchtext": "0.14.0", "torchaudio": "0.13.0"},
{"torch": "1.12.1", "torchvision": "0.13.1", "torchtext": "0.13.1", "torchaudio": "0.12.1"},
{"torch": "1.12.0", "torchvision": "0.13.0", "torchtext": "0.13.0", "torchaudio": "0.12.0"},
{"torch": "1.11.0", "torchvision": "0.12.0", "torchtext": "0.12.0", "torchaudio": "0.11.0"},
{"torch": "1.10.2", "torchvision": "0.11.3", "torchtext": "0.11.2", "torchaudio": "0.10.2"},
{"torch": "1.10.1", "torchvision": "0.11.2", "torchtext": "0.11.1", "torchaudio": "0.10.1"},
{"torch": "1.10.0", "torchvision": "0.11.1", "torchtext": "0.11.0", "torchaudio": "0.10.0"},
{"torch": "1.9.1", "torchvision": "0.10.1", "torchtext": "0.10.1", "torchaudio": "0.9.1"},
{"torch": "1.9.0", "torchvision": "0.10.0", "torchtext": "0.10.0", "torchaudio": "0.9.0"},
{"torch": "1.8.2", "torchvision": "0.9.1", "torchtext": "0.9.1", "torchaudio": "0.8.1"},
{"torch": "1.8.1", "torchvision": "0.9.1", "torchtext": "0.9.1", "torchaudio": "0.8.1"},
{"torch": "1.8.0", "torchvision": "0.9.0", "torchtext": "0.9.0", "torchaudio": "0.8.0"},
]

def _determine_torchaudio(torch_version: str) -> str:
"""Determine the torchaudio version based on the torch version.
>>> _determine_torchaudio("1.9.0")
'0.9.0'
>>> _determine_torchaudio("2.4.1")
'2.4.1'
>>> _determine_torchaudio("1.8.2")
'0.9.1'
"""
_version_exceptions = {
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
ta_ver_array = [ver_major, ver_minor, ver_bugfix]
if ver_major == 1:
ta_ver_array[0] = 0
ta_ver_array[2] = ver_bugfix
return ".".join(map(str, ta_ver_array))


def _determine_torchtext(torch_version: str) -> str:
"""Determine the torchtext version based on the torch version.
>>> _determine_torchtext("1.9.0")
'0.10.0'
>>> _determine_torchtext("2.4.1")
'0.18.0'
>>> _determine_torchtext("1.8.2")
'0.9.1'
"""
_version_exceptions = {
"2.0.1": "0.15.2",
"2.0.0": "0.15.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tt_ver_array = [0, 0, 0]
if ver_major == 1:
tt_ver_array[1] = ver_minor + 1
tt_ver_array[2] = ver_bugfix
elif ver_major == 2:
if ver_minor >= 3:
tt_ver_array[1] = 18
else:
tt_ver_array[1] = ver_minor + 15
tt_ver_array[2] = ver_bugfix
else:
raise ValueError(f"Invalid torch version: {torch_version}")
return ".".join(map(str, tt_ver_array))


def _determine_torchvision(torch_version: str) -> str:
"""Determine the torchvision version based on the torch version.
>>> _determine_torchvision("1.9.0")
'0.10.0'
>>> _determine_torchvision("2.4.1")
'0.19.1'
>>> _determine_torchvision("2.0.1")
'0.15.2'
"""
_version_exceptions = {
"2.0.1": "0.15.2",
"2.0.0": "0.15.1",
"1.10.2": "0.11.3",
"1.10.1": "0.11.2",
"1.10.0": "0.11.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
tv_ver_array = [0, 0, 0]
if ver_major == 1:
tv_ver_array[1] = ver_minor + 1
elif ver_major == 2:
tv_ver_array[1] = ver_minor + 15
else:
raise ValueError(f"Invalid torch version: {torch_version}")
tv_ver_array[2] = ver_bugfix
return ".".join(map(str, tv_ver_array))


def find_latest(ver: str) -> Dict[str, str]:
"""Find the latest version."""
"""Find the latest version.
>>> from pprint import pprint
>>> pprint(find_latest("2.1.0"))
{'torch': '2.1.0',
'torchaudio': '2.1.0',
'torchtext': '0.16.0',
'torchvision': '0.16.0'}
"""
# drop all except semantic version
ver = re.search(r"([\.\d]+)", ver).groups()[0]
# in case there remaining dot at the end - e.g "1.9.0.dev20210504"
ver = ver[:-1] if ver[-1] == "." else ver
logging.debug(f"finding ecosystem versions for: {ver}")

# find first match
for option in VERSIONS:
if option["torch"].startswith(ver):
return option

raise ValueError(f"Missing {ver} in {VERSIONS}")
return {
"torch": ver,
"torchvision": _determine_torchvision(ver),
"torchtext": _determine_torchtext(ver),
"torchaudio": _determine_torchaudio(ver),
}


def adjust(requires: List[str], pytorch_version: Optional[str] = None) -> List[str]:
"""Adjust the versions to be paired within pytorch ecosystem."""
"""Adjust the versions to be paired within pytorch ecosystem.
>>> from pprint import pprint
>>> pprint(adjust(["torch>=1.9.0", "torchvision>=0.10.0", "torchtext>=0.10.0", "torchaudio>=0.9.0"], "2.1.0"))
['torch==2.1.0',
'torchvision==0.16.0',
'torchtext==0.16.0',
'torchaudio==2.1.0']
"""
if not pytorch_version:
import torch

Expand Down

0 comments on commit 1d0e380

Please sign in to comment.